Quickly Train a Multimodal Model: CLIP
Multimodal learning is a very active area right now, and multimodal models open up many downstream applications, so knowing how to train one on your own data is a useful skill. This post walks through how to quickly get a CLIP-style multimodal model training on your own dataset.
Dataset preparation
CLIP-style multimodal models are trained on image-text pairs: the images sit in a directory on disk, and the (image path, caption) pairs are listed in a JSON file that is simply read in at load time.
The JSON file is structured as follows:
[
  {
    "image": "/image/0213123.jpg",
    "caption": "a man with red hat"
  },
  ...
  {
    "image": "/image/7213123.jpg",
    "caption": "a jk woman with glass in class"
  }
]
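If your annotations start out in a flat file such as a CSV, a few lines of Python are enough to produce this index. The sketch below is only an illustration: captions.csv, its column names, and train_pairs.json are assumed names, not part of the original pipeline.

# minimal sketch (assumed file names): build the JSON index from a captions.csv
# that has "image" and "caption" columns
import csv
import json

def build_index(csv_path, json_path):
    records = []
    with open(csv_path, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            records.append({"image": row["image"], "caption": row["caption"]})
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(records, f, ensure_ascii=False, indent=2)

build_index("captions.csv", "train_pairs.json")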
Dataset loading
Follow the principle of loading only what you actually need:
class MyDataset(torch.utils.data.Dataset):
def __init__(self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""):
super().__init__()
self.tokenizer = tokenizer
self.size = size
self.i_drop_rate = i_drop_rate
self.t_drop_rate = t_drop_rate
self.ti_drop_rate = ti_drop_rate
self.image_root_path = image_root_path
        # load the (image path, caption) index; the dataset is capped at the first 500k samples here
        self.data = json.load(open(json_file))[:500000]
self.transform = transforms.Compose([
transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(self.size),
transforms.ToTensor(),
])
def __getitem__(self, idx):
item = self.data[idx]
        text = item["caption"]  # the key name must match the JSON index described above
image_file = item["image"]
# read image
raw_image = Image.open(os.path.join(self.image_root_path, image_file))
image = self.transform(raw_image.convert("RGB"))
return {
"image": image,
"text": text,
}
def __len__(self):
return len(self.data)
def collate_fn(data):
images = torch.stack([example["image"] for example in data])
text = [example["text"] for example in data]
return {
"images": images,
"text": text,
}
# the tokenizer is handed to the dataset but not used there; tokenization happens inside CLIPProcessor during training
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
# dataloader
train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
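Before wiring the dataloader into training, it is worth pulling a single batch to confirm shapes and types. This quick check is not part of the original script; it only assumes the dataloader defined above.

# quick sanity check (illustrative): fetch one batch and inspect it
batch = next(iter(train_dataloader))
print(batch["images"].shape)                 # expected: [train_batch_size, 3, resolution, resolution]
print(len(batch["text"]), batch["text"][0])  # captions stay raw strings until the processor tokenizes them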
Building the model
These days the quickest way to assemble the model is with the Hugging Face transformers library:
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
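To confirm the pretrained weights loaded correctly, a quick zero-shot check before any training is useful. This is a minimal sketch: cat.jpg and the two candidate captions are placeholders, not files from the original setup.

# zero-shot sanity check (cat.jpg is an assumed local image, not part of the training data)
from PIL import Image

image = Image.open("cat.jpg")
inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)
outputs = model(**inputs)
print(outputs.logits_per_image.softmax(dim=-1))  # the higher probability should land on the matching caption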
Full CLIP training script
import time
import torch
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
import argparse
from torchvision import transforms
import json
from tqdm import tqdm
import torch.nn.functional as F
from accelerate import Accelerator
from PIL import Image
import os

# set environment variables: route downloads through the Hugging Face mirror
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default='runwayml/stable-diffusion-v1-5',
        help="Path or hub id of a pretrained pipeline; only its 'tokenizer' subfolder (a CLIP tokenizer) is loaded from it.",
    )
parser.add_argument(
"--data_json_file",
type=str,
default='HQ_FaceCaption_V1_1.json',
help="Training data",
)
parser.add_argument(
"--data_root_path",
type=str,
default='/home/ddwgroup/workplace/hq1M/',
help="Training data root path",
)
    parser.add_argument(
        "--output_dir",
        type=str,
        default="clip-finetuned",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--resolution",
type=int,
default=224,
help=(
"The resolution for input images"
),
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Learning rate to use.",
)
parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument(
"--train_batch_size", type=int, default=24, help="Batch size (per device) for the training dataloader."
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument(
"--save_steps",
type=int,
default=1000,
help=(
"Save a checkpoint of the training state every X updates"
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default="fp16",
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
return args
# Dataset
class MyDataset(torch.utils.data.Dataset):
def __init__(self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""):
super().__init__()
self.tokenizer = tokenizer
self.size = size
self.i_drop_rate = i_drop_rate
self.t_drop_rate = t_drop_rate
self.ti_drop_rate = ti_drop_rate
self.image_root_path = image_root_path
        # load the (image path, caption) index; the dataset is capped at the first 500k samples here
        self.data = json.load(open(json_file))[:500000]
self.transform = transforms.Compose([
transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(self.size),
transforms.ToTensor(),
])
def __getitem__(self, idx):
item = self.data[idx]
        text = item["caption"]  # the key name must match the JSON index described above
image_file = item["image"]
# read image
raw_image = Image.open(os.path.join(self.image_root_path, image_file))
image = self.transform(raw_image.convert("RGB"))
return {
"image": image,
"text": text,
}
def __len__(self):
return len(self.data)
def collate_fn(data):
images = torch.stack([example["image"] for example in data])
text = [example["text"] for example in data]
return {
"images": images,
"text": text,
}
args = parse_args()
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
# dataloader
train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
# accelerator / environment setup
accelerator = Accelerator(mixed_precision=args.mixed_precision)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
model, dataloader, optimizer = accelerator.prepare(model, train_dataloader, optimizer)
for epoch in range(args.num_train_epochs):
model.train()
total_loss = 0
progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.num_train_epochs}")
for batch in progress_bar:
start_time = time.time()
optimizer.zero_grad()
        # the dataset already returns image tensors scaled to [0, 1], so turn off the processor's own rescaling
        inputs = processor(text=batch['text'], images=batch['images'], return_tensors="pt", padding=True, do_rescale=False).to(accelerator.device)
        outputs = model(**inputs)
        # CLIP's symmetric contrastive loss: in a batch of N pairs the i-th image should
        # match the i-th caption, so the targets along both axes are simply 0..N-1
        logits_per_image = outputs.logits_per_image
        logits_per_text = outputs.logits_per_text
        ground_truth = torch.arange(len(logits_per_image), device=logits_per_image.device)
        loss_img = F.cross_entropy(logits_per_image, ground_truth)
        loss_txt = F.cross_entropy(logits_per_text, ground_truth)
        loss = (loss_img + loss_txt) / 2
accelerator.backward(loss)
optimizer.step()
total_loss += loss.item()
end_time = time.time()
time_per_iter = end_time - start_time
progress_bar.set_postfix({"loss": f"{loss.item():.4f}", "time_per_iter": f"{time_per_iter:.2f}s"})
print(f"Epoch [{epoch+1}/{args.num_train_epochs}] completed with average loss {total_loss/len(dataloader):.4f}")