Quickly Training a Multimodal Model (CLIP)

Multimodal learning is currently a very active area, and multimodal models enable a wide range of applications, so knowing how to train one matters. This post walks through how to quickly get started training a multimodal model on your own dataset.

Dataset Preparation

CLIP-style multimodal models are trained mainly on image-text pairs: the images live in their own directory, while the captions and the corresponding image paths are stored in a JSON file that gets read at training time.

The JSON file is structured as follows:

[
    {
        "image": "/image/0213123.jpg",
        "caption": "a man with red hat"
    },
    ...
    {
        "image": "/image/7213123.jpg",
        "caption": "a jk woman with glass in class"
    }
]
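
If your captions are stored elsewhere (say, a tab-separated text file with one image filename and one caption per line), a small script can assemble this JSON. The sketch below assumes that hypothetical captions.txt layout and the /image root from the example above; adapt the paths and parsing to your own data.

import json
import os

# Hypothetical input: captions.txt with lines of the form "<filename>\t<caption>".
image_root = "/image"
records = []
with open("captions.txt", encoding="utf-8") as f:
    for line in f:
        filename, caption = line.rstrip("\n").split("\t", 1)
        records.append({
            "image": os.path.join(image_root, filename),
            "caption": caption,
        })

# Write the image-caption pairs in the structure shown above.
with open("train_pairs.json", "w", encoding="utf-8") as f:
    json.dump(records, f, ensure_ascii=False, indent=4)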

Dataset Loading

The guiding principle is to load only what you actually need:

class MyDataset(torch.utils.data.Dataset):

    def __init__(self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""):
        super().__init__()

        self.tokenizer = tokenizer
        self.size = size
        self.i_drop_rate = i_drop_rate
        self.t_drop_rate = t_drop_rate
        self.ti_drop_rate = ti_drop_rate
        self.image_root_path = image_root_path

        self.data = json.load(open(json_file))[:500000]
        self.transform = transforms.Compose([
            transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.size),
            transforms.ToTensor(),
        ])

    def __getitem__(self, idx):
        item = self.data[idx] 
        text = item["image_short_caption"][0]
        image_file = item["image"]
        
        # read image
        raw_image = Image.open(os.path.join(self.image_root_path, image_file))
        image = self.transform(raw_image.convert("RGB"))

        return {
            "image": image,
            "text": text,

        }

    def __len__(self):
        return len(self.data)
    

def collate_fn(data):
    images = torch.stack([example["image"] for example in data])
    text = [example["text"] for example in data]
    

    return {
        "images": images,
        "text": text,
    }

tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
# dataloader
train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path)
train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
)
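
Before wiring up the model, it is worth pulling a single batch to confirm that the shapes and types match what the training loop expects. A quick sanity check, assuming the dataloader above was constructed successfully:

# Fetch one batch and inspect it (sanity check only).
batch = next(iter(train_dataloader))
print(batch["images"].shape)                        # expected: [train_batch_size, 3, resolution, resolution]
print(len(batch["text"]), type(batch["text"][0]))   # train_batch_size captions, each a str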

Model Setup

These days the fastest way to put the model together is with the transformers library:

from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
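
The CLIPProcessor bundles the image preprocessing and the tokenizer, so a single forward pass is enough to verify that the model produces the image-text similarity logits the training loop relies on. A minimal sketch (the image path is a placeholder taken from the example JSON above):

from PIL import Image
import torch

# One image against two candidate captions; logits_per_image has shape [1, 2].
image = Image.open("/image/0213123.jpg").convert("RGB")
texts = ["a man with red hat", "a photo of a dog"]

inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits_per_image.softmax(dim=-1))  # higher score for the matching caption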

CLIP Training Code

import time
import torch
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
import argparse
from torchvision import transforms
import json
from tqdm import tqdm
import torch.nn.functional as F
from accelerate import Accelerator
from PIL import Image
import os
# set environment variable: route Hugging Face downloads through the mirror
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default='runwayml/stable-diffusion-v1-5',
        
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_ip_adapter_path",
        type=str,
        default=None,
        help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
    )
    parser.add_argument(
        "--data_json_file",
        type=str,
        default='HQ_FaceCaption_V1_1.json',
      
        help="Training data",
    )
    parser.add_argument(
        "--data_root_path",
        type=str,
        default='/home/ddwgroup/workplace/hq1M/',
       
        help="Training data root path",
    )
    parser.add_argument(
        "--image_encoder_path",
        type=str,
        default='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k',
        # required=True,
        help="Path to CLIP image encoder",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-ip_adapter-face",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=224,
        help=(
            "The resolution for input images"
        ),
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Learning rate to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--train_batch_size", type=int, default=24, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=1000,
        help=(
            "Save a checkpoint of the training state every X updates"
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="fp16",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    
    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args

# Dataset
class MyDataset(torch.utils.data.Dataset):

    def __init__(self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""):
        super().__init__()

        self.tokenizer = tokenizer
        self.size = size
        self.i_drop_rate = i_drop_rate
        self.t_drop_rate = t_drop_rate
        self.ti_drop_rate = ti_drop_rate
        self.image_root_path = image_root_path

        self.data = json.load(open(json_file))[:500000]
        self.transform = transforms.Compose([
            transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.size),
            transforms.ToTensor(),
        ])

    def __getitem__(self, idx):
        item = self.data[idx] 
        text = item["image_short_caption"][0]
        image_file = item["image"]
        
        # read image
        raw_image = Image.open(os.path.join(self.image_root_path, image_file))
        image = self.transform(raw_image.convert("RGB"))

        return {
            "image": image,
            "text": text,

        }

    def __len__(self):
        return len(self.data)
    

def collate_fn(data):
    images = torch.stack([example["image"] for example in data])
    text = [example["text"] for example in data]
    

    return {
        "images": images,
        "text": text,
    }

args = parse_args()


tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
# dataloader
train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path)
train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
)
    
# training setup: Accelerate handles device placement and (optionally) mixed precision
accelerator = Accelerator(mixed_precision=args.mixed_precision)

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
model, dataloader, optimizer = accelerator.prepare(model, train_dataloader, optimizer)


for epoch in range(args.num_train_epochs):
    model.train()
    total_loss = 0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.num_train_epochs}")
    for batch in progress_bar:
        start_time = time.time()
        optimizer.zero_grad()
        inputs = processor(text=batch['text'], images=batch['images'], return_tensors="pt", padding=True, do_rescale=False).to(accelerator.device)
        outputs = model(**inputs)

        # CLIP contrastive loss: matching image-text pairs sit on the diagonal of the
        # similarity matrix, so the target for both directions is each sample's index.
        logits_per_image = outputs.logits_per_image
        logits_per_text = outputs.logits_per_text

        ground_truth = torch.arange(len(logits_per_image), device=logits_per_image.device)
        loss_img = F.cross_entropy(logits_per_image, ground_truth)
        loss_txt = F.cross_entropy(logits_per_text, ground_truth)
        loss = (loss_img + loss_txt) / 2

        accelerator.backward(loss)
        optimizer.step()


        total_loss += loss.item()
        end_time = time.time()
        time_per_iter = end_time - start_time

        progress_bar.set_postfix({"loss": f"{loss.item():.4f}", "time_per_iter": f"{time_per_iter:.2f}s"})

    print(f"Epoch [{epoch+1}/{args.num_train_epochs}] completed with average loss {total_loss/len(dataloader):.4f}")