加载数据
在线加载
from datasets import load_dataset
dataset = load_dataset("imdb")加载本地文件
dataset = load_dataset("csv", data_files="my_data.csv")
split_dataset = load_dataset("text", data_files={
'train': "datasets/TinyStories-train.txt",
'val': "datasets/TinyStories-valid.txt"
})数据集结构
返回对象为 DatasetDict,包含 train, test 等切片
print(dataset)
# 输出示例:
# DatasetDict({
# train: Dataset({
# features: ['text', 'label'],
# num_rows: 25000
# })
# })数据处理
map 函数
对数据集中的每一行应用一个自定义函数。
举个例子,给所有记录添加一个前缀:
def add_prefix(example):
return {"text": "Prefix: " + example["text"]}
updated_dataset = dataset.map(add_prefix)map 函数有几个重要参数:
batched
默认情况下,map 是逐行处理的,开启 batched=True 后,函数会一次性接收一个列表,这可以减少函数调用次数。
def batch_tokenize(examples):
return tokenizer(examples["text"], padding=True)
dataset.map(batch_tokenize, batched=True, batch_size=1000)num_proc
多进程并行,利用多核 GPU 可以提升性能。num_proc=N 表示开启 N 个进程处理。
remove_columns
处理完数据后,常常会把原始数据删除,此时就需要把该列删除以节省内存,比如 tokenize 后不再需要原始文本:
tokenized = dataset.map(process_func, batched=True, remove_columns=["text"])其他常用转换函数
filter: 筛选符合条件的数据shuffle: 打乱数据remove_columns: 删除列rename_column: 重命名列
Dataset
shard
将数据集切分为多个均匀的小块。
dataset.shard(num_shards=1024, index=0, contiguous=True)-
contiguous: 是否取连续的块True: 连续的一段False(默认): 通过取模运算取样
-
解决内存溢出:数据集过大时,不能一次完成处理,可以循环 shard 每次处理一份,处理后释放
-
分布式训练:做数据并行,每个 GPU 只看数据集其中的几份