加载数据

在线加载

from datasets import load_dataset
dataset = load_dataset("imdb")

加载本地文件

dataset = load_dataset("csv", data_files="my_data.csv")
 
split_dataset = load_dataset("text", data_files={
    'train': "datasets/TinyStories-train.txt",
    'val': "datasets/TinyStories-valid.txt"
})

数据集结构

返回对象为 DatasetDict,包含 train, test 等切片

print(dataset)
# 输出示例:
# DatasetDict({
# 	train: Dataset({
# 		features: ['text', 'label'],
# 		num_rows: 25000
# 	})
# })

数据处理

map 函数

对数据集中的每一行应用一个自定义函数。

举个例子,给所有记录添加一个前缀:

def add_prefix(example):
    return {"text": "Prefix: " + example["text"]}
 
updated_dataset = dataset.map(add_prefix)

map 函数有几个重要参数:

batched

默认情况下,map 是逐行处理的,开启 batched=True 后,函数会一次性接收一个列表,这可以减少函数调用次数

def batch_tokenize(examples):
    return tokenizer(examples["text"], padding=True)
 
dataset.map(batch_tokenize, batched=True, batch_size=1000)

num_proc

多进程并行,利用多核 GPU 可以提升性能。num_proc=N 表示开启 N 个进程处理。

remove_columns

处理完数据后,常常会把原始数据删除,此时就需要把该列删除以节省内存,比如 tokenize 后不再需要原始文本:

tokenized = dataset.map(process_func, batched=True, remove_columns=["text"])

其他常用转换函数

  • filter: 筛选符合条件的数据
  • shuffle: 打乱数据
  • remove_columns: 删除列
  • rename_column: 重命名列

Dataset

shard

将数据集切分为多个均匀的小块。

dataset.shard(num_shards=1024, index=0, contiguous=True)
  • contiguous: 是否取连续的块

    • True: 连续的一段
    • False (默认): 通过取模运算取样
  • 解决内存溢出:数据集过大时,不能一次完成处理,可以循环 shard 每次处理一份,处理后释放

  • 分布式训练:做数据并行,每个 GPU 只看数据集其中的几份