Paint Board Deep-Dive Series: Intelligent Image Processing with Transformers.js

Published by LH_S on 2024-11-23

Introduction

I currently maintain a feature-rich open-source creative paint board. It integrates many interesting brushes and drawing aids that give users an entirely new drawing experience, with a polished interactive experience on both mobile and desktop.

In this article I will walk through how to use Transformers.js to implement background removal and image mark-based segmentation. You can get a feel for the results at the links below.

Live demo: https://songlh.top/paint-board/

GitHub: https://github.com/LHRUN/paint-board (Stars ⭐️ welcome)

About Transformers.js

Transformers.js is a powerful JavaScript library built on Hugging Face's Transformers that runs directly in the browser, with no server-side computation required. This means you can run models entirely locally, which boosts efficiency and cuts deployment and maintenance costs.

Transformers.js currently offers 1000+ models on Hugging Face, covering a wide range of domains. Most needs, such as image processing, text generation, translation, and sentiment analysis, can be handled with it directly. You can find compatible models by searching the Hugging Face Hub and filtering for Transformers.js support.
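For instance, most tasks take only a few lines through the pipeline API. Here is a minimal sketch (sentiment analysis is just an example task; on first use the model is downloaded and cached, after which everything runs client-side):

import { pipeline } from '@huggingface/transformers'

// Minimal pipeline sketch: the model is fetched once, cached by the
// browser, and then runs fully client-side.
const classifier = await pipeline('sentiment-analysis')
const result = await classifier('I love this paint board!')
console.log(result) // e.g. [{ label: 'POSITIVE', score: 0.99 }]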

Transformers.js has now reached major version V3, which introduces many significant features; see Transformers.js v3: WebGPU Support, New Models & Tasks, and More… for the details.

Both features in this article rely on the WebGPU support added in V3, which speeds up processing enormously; inference now completes at the millisecond level. Be aware, though, that WebGPU support is still limited across browsers, so I recommend using the latest version of Chrome.
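If you would rather degrade gracefully than simply block unsupported browsers, a minimal feature-detection sketch could look like the following (falling back to the WASM backend is my own assumption here, and it will be noticeably slower than WebGPU):

import { AutoModel } from '@huggingface/transformers'

// Minimal sketch: prefer WebGPU when available, otherwise fall back to WASM.
const device = 'gpu' in navigator ? 'webgpu' : 'wasm'
const model = await AutoModel.from_pretrained('Xenova/modnet', { device })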

Feature 1: Background Removal

For background removal I use the Xenova/modnet model; you can try the effect in the live demo.

The processing logic breaks down into three steps:

  1. Initialize the state and load the model and processor
  2. On button click, load and preprocess the image, run the model to produce an alpha mask, then apply that mask to the original image's pixels on a canvas to produce the background-removed image
  3. Render the UI. This part is entirely up to your own design; a popular approach is a draggable divider that dynamically compares the image before and after background removal (a sketch of this follows the component code below)

The rough code logic is below (React + TS). For the full details, see my project's source at src/components/boardOperation/uploadImage/index.tsx

import { useState, FC, useRef, useEffect, useMemo } from 'react'
import {
  env,
  AutoModel,
  AutoProcessor,
  RawImage,
  PreTrainedModel,
  Processor
} from '@huggingface/transformers'

const REMOVE_BACKGROUND_STATUS = {
  LOADING: 0, // loading the model
  NO_SUPPORT_WEBGPU: 1, // WebGPU not supported
  LOAD_ERROR: 2, // failed to load
  LOAD_SUCCESS: 3, // loaded successfully
  PROCESSING: 4, // processing
  PROCESSING_SUCCESS: 5 // processed successfully
}

type RemoveBackgroundStatusType =
  (typeof REMOVE_BACKGROUND_STATUS)[keyof typeof REMOVE_BACKGROUND_STATUS]

const UploadImage: FC<{ url: string }> = ({ url }) => {
  const [removeBackgroundStatus, setRemoveBackgroundStatus] =
    useState<RemoveBackgroundStatusType>()
  const [processedImage, setProcessedImage] = useState('')

  const modelRef = useRef<PreTrainedModel>()
  const processorRef = useRef<Processor>()

  const removeBackgroundBtnTip = useMemo(() => {
    switch (removeBackgroundStatus) {
      case REMOVE_BACKGROUND_STATUS.LOADING:
        return 'Loading the background removal feature'
      case REMOVE_BACKGROUND_STATUS.NO_SUPPORT_WEBGPU:
        return 'This browser does not support WebGPU. To use background removal, please use the latest version of Chrome'
      case REMOVE_BACKGROUND_STATUS.LOAD_ERROR:
        return 'Failed to load the background removal feature'
      case REMOVE_BACKGROUND_STATUS.LOAD_SUCCESS:
        return 'Background removal loaded successfully'
      case REMOVE_BACKGROUND_STATUS.PROCESSING:
        return 'Removing background'
      case REMOVE_BACKGROUND_STATUS.PROCESSING_SUCCESS:
        return 'Background removed successfully'
      default:
        return ''
    }
  }, [removeBackgroundStatus])

  useEffect(() => {
    ;(async () => {
      try {
        if (removeBackgroundStatus === REMOVE_BACKGROUND_STATUS.LOADING) {
          return
        }
        setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.LOADING)

        // Check WebGPU support
        if (!navigator?.gpu) {
          setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.NO_SUPPORT_WEBGPU)
          return
        }
        const model_id = 'Xenova/modnet'
        if (env.backends.onnx.wasm) {
          env.backends.onnx.wasm.proxy = false
        }

        // Load the model and processor
        modelRef.current ??= await AutoModel.from_pretrained(model_id, {
          device: 'webgpu'
        })
        processorRef.current ??= await AutoProcessor.from_pretrained(model_id)
        setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.LOAD_SUCCESS)
      } catch (err) {
        console.log('err', err)
        setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.LOAD_ERROR)
      }
    })()
  }, [])

  const processImages = async () => {
    const model = modelRef.current
    const processor = processorRef.current

    if (!model || !processor) {
      return
    }

    setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.PROCESSING)

    // Load the image
    const img = await RawImage.fromURL(url)

    // Preprocess the image
    const { pixel_values } = await processor(img)

    // Generate the image mask
    const { output } = await model({ input: pixel_values })
    const maskData = (
      await RawImage.fromTensor(output[0].mul(255).to('uint8')).resize(
        img.width,
        img.height
      )
    ).data

    // Create a new canvas
    const canvas = document.createElement('canvas')
    canvas.width = img.width
    canvas.height = img.height
    const ctx = canvas.getContext('2d') as CanvasRenderingContext2D

    // Draw the original image
    ctx.drawImage(img.toCanvas(), 0, 0)

    // Write the mask into the alpha channel
    const pixelData = ctx.getImageData(0, 0, img.width, img.height)
    for (let i = 0; i < maskData.length; ++i) {
      pixelData.data[4 * i + 3] = maskData[i]
    }
    ctx.putImageData(pixelData, 0, 0)

    // Save the new image
    setProcessedImage(canvas.toDataURL('image/png'))
    setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.PROCESSING_SUCCESS)
  }

  // UI rendering
  return (
    <div className="card shadow-xl">
      <button
        className={`btn btn-primary btn-sm ${
          ![
            REMOVE_BACKGROUND_STATUS.LOAD_SUCCESS,
            REMOVE_BACKGROUND_STATUS.PROCESSING_SUCCESS,
            undefined
          ].includes(removeBackgroundStatus)
            ? 'btn-disabled'
            : ''
        }`}
        onClick={processImages}
      >
        Remove Background
      </button>
      <div className="text-xs text-base-content mt-2 flex">
        {removeBackgroundBtnTip}
      </div>
      <div className="relative mt-4 border border-base-content border-dashed rounded-lg overflow-hidden">
        <img
          className={`w-[50vw] max-w-[400px] h-[50vh] max-h-[400px] object-contain`}
          src={url}
        />
        {processedImage && (
          <img
            className={`w-full h-full absolute top-0 left-0 z-[2] object-contain`}
            src={processedImage}
          />
        )}
      </div>
    </div>
  )
}

export default UploadImage
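The component above simply stacks the processed image on top of the original. The divider-style before/after comparison mentioned in step 3 is not part of my source; a minimal sketch using CSS clip-path, reusing the url and processedImage values from the component above (the percent slider state is hypothetical), might look like this:

import { FC, useState } from 'react'

// Hedged sketch of a before/after divider: the processed image is clipped
// to the left `percent` of the container, so dragging the range input
// sweeps the comparison line across the image.
const CompareSlider: FC<{ url: string; processedImage: string }> = ({
  url,
  processedImage
}) => {
  const [percent, setPercent] = useState(50)
  return (
    <div className="relative w-[400px]">
      <img className="w-full object-contain" src={url} />
      <img
        className="w-full absolute top-0 left-0 object-contain"
        style={{ clipPath: `inset(0 ${100 - percent}% 0 0)` }}
        src={processedImage}
      />
      <input
        type="range"
        min={0}
        max={100}
        value={percent}
        onChange={(e) => setPercent(Number(e.target.value))}
        className="w-full mt-2"
      />
    </div>
  )
}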

Feature 2: Image Mark-Based Segmentation

Image mark-based segmentation is implemented with the Xenova/slimsam-77-uniform model. Once the model has loaded you can click the image, and a segmentation mask is generated from the coordinates you click.

The processing logic breaks down into five steps:

  1. Initialize the state and load the model and processor
  2. Fetch and load the image, then store the loaded image data and its embeddings
  3. Listen for clicks on the image and record the click data, split into positive and negative marks. After each click, decode the click data into mask data and draw the segmentation overlay from it
  4. Render the UI. This part is entirely up to your own design
  5. On save, match the original image's pixels against the mask pixels, then draw and export the result via canvas (a sketch of this step follows the component code below)

The rough code logic is below (React + TS). For the full details, see my project's source at src/components/boardOperation/uploadImage/imageSegmentation.tsx

import { useState, useRef, useEffect, useMemo, MouseEvent, FC } from 'react'
import {
  SamModel,
  AutoProcessor,
  RawImage,
  PreTrainedModel,
  Processor,
  Tensor,
  SamImageProcessorResult
} from '@huggingface/transformers'

import LoadingIcon from '@/components/icons/loading.svg?react'
import PositiveIcon from '@/components/icons/boardOperation/image-segmentation-positive.svg?react'
import NegativeIcon from '@/components/icons/boardOperation/image-segmentation-negative.svg?react'

interface MarkPoint {
  position: number[]
  label: number
}

// Processing status
const SEGMENTATION_STATUS = {
  LOADING: 0, // loading the model
  NO_SUPPORT_WEBGPU: 1, // WebGPU not supported
  LOAD_ERROR: 2, // failed to load the model
  LOAD_SUCCESS: 3, // model loaded successfully
  PROCESSING: 4, // processing the image
  PROCESSING_SUCCESS: 5 // image processed successfully
}

type SegmentationStatusType =
  (typeof SEGMENTATION_STATUS)[keyof typeof SEGMENTATION_STATUS]

const ImageSegmentation: FC<{ url: string }> = ({ url }) => {
  const [markPoints, setMarkPoints] = useState<MarkPoint[]>([])
  const [segmentationStatus, setSegmentationStatus] =
    useState<SegmentationStatusType>()
  const [pointStatus, setPointStatus] = useState<boolean>(true)

  const maskCanvasRef = useRef<HTMLCanvasElement>(null) // segmentation mask
  const modelRef = useRef<PreTrainedModel>() // model
  const processorRef = useRef<Processor>() // processor
  const imageInputRef = useRef<RawImage>() // original image
  const imageProcessed = useRef<SamImageProcessorResult>() // processed image
  const imageEmbeddings = useRef<Tensor>() // embedding data

  const segmentationTip = useMemo(() => {
    switch (segmentationStatus) {
      case SEGMENTATION_STATUS.LOADING:
        return 'Loading the image segmentation feature'
      case SEGMENTATION_STATUS.NO_SUPPORT_WEBGPU:
        return 'This browser does not support WebGPU. To use image segmentation, please use the latest version of Chrome'
      case SEGMENTATION_STATUS.LOAD_ERROR:
        return 'Failed to load the image segmentation feature'
      case SEGMENTATION_STATUS.LOAD_SUCCESS:
        return 'Image segmentation loaded successfully'
      case SEGMENTATION_STATUS.PROCESSING:
        return 'Processing the image'
      case SEGMENTATION_STATUS.PROCESSING_SUCCESS:
        return 'Image processed successfully. Click the image to add marks; the green overlay is the segmented region'
      default:
        return ''
    }
  }, [segmentationStatus])

  // 1. Load the model and processor
  useEffect(() => {
    ;(async () => {
      try {
        if (segmentationStatus === SEGMENTATION_STATUS.LOADING) {
          return
        }

        setSegmentationStatus(SEGMENTATION_STATUS.LOADING)
        if (!navigator?.gpu) {
          setSegmentationStatus(SEGMENTATION_STATUS.NO_SUPPORT_WEBGPU)
          return
        }
        const model_id = 'Xenova/slimsam-77-uniform'
        modelRef.current ??= await SamModel.from_pretrained(model_id, {
          dtype: 'fp16', // or "fp32"
          device: 'webgpu'
        })
        processorRef.current ??= await AutoProcessor.from_pretrained(model_id)

        setSegmentationStatus(SEGMENTATION_STATUS.LOAD_SUCCESS)
      } catch (err) {
        console.log('err', err)
        setSegmentationStatus(SEGMENTATION_STATUS.LOAD_ERROR)
      }
    })()
  }, [])

  // 2. Process the image
  useEffect(() => {
    ;(async () => {
      try {
        if (
          !modelRef.current ||
          !processorRef.current ||
          !url ||
          segmentationStatus === SEGMENTATION_STATUS.PROCESSING
        ) {
          return
        }
        setSegmentationStatus(SEGMENTATION_STATUS.PROCESSING)
        clearPoints()

        imageInputRef.current = await RawImage.fromURL(url)
        imageProcessed.current = await processorRef.current(
          imageInputRef.current
        )
        imageEmbeddings.current = await (
          modelRef.current as any
        ).get_image_embeddings(imageProcessed.current)

        setSegmentationStatus(SEGMENTATION_STATUS.PROCESSING_SUCCESS)
      } catch (err) {
        console.log('err', err)
      }
    })()
  }, [url, modelRef.current, processorRef.current])

  // Update the mask overlay
  function updateMaskOverlay(mask: RawImage, scores: Float32Array) {
    const maskCanvas = maskCanvasRef.current
    if (!maskCanvas) {
      return
    }
    const maskContext = maskCanvas.getContext('2d') as CanvasRenderingContext2D

    // Resize the canvas to match the mask
    if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
      maskCanvas.width = mask.width
      maskCanvas.height = mask.height
    }

    // Allocate a pixel buffer for the overlay
    const imageData = maskContext.createImageData(
      maskCanvas.width,
      maskCanvas.height
    )

    const numMasks = scores.length // 3
    let bestIndex = 0
    for (let i = 1; i < numMasks; ++i) {
      if (scores[i] > scores[bestIndex]) {
        bestIndex = i
      }
    }

    // Fill the selected pixels with the mask colour. mask.data holds
    // numMasks values per pixel, so iterate over pixels, not bytes.
    const pixelData = imageData.data
    const pixelCount = pixelData.length / 4
    for (let i = 0; i < pixelCount; ++i) {
      if (mask.data[numMasks * i + bestIndex] === 1) {
        const offset = 4 * i
        pixelData[offset] = 101 // r
        pixelData[offset + 1] = 204 // g
        pixelData[offset + 2] = 138 // b
        pixelData[offset + 3] = 255 // a
      }
    }

    // Draw the overlay
    maskContext.putImageData(imageData, 0, 0)
  }

  // 3. Decode based on the click data
  const decode = async (markPoints: MarkPoint[]) => {
    if (
      !modelRef.current ||
      !imageEmbeddings.current ||
      !processorRef.current ||
      !imageProcessed.current
    ) {
      return
    }

    // With no marks, just clear the segmentation overlay
    if (!markPoints.length && maskCanvasRef.current) {
      const maskContext = maskCanvasRef.current.getContext(
        '2d'
      ) as CanvasRenderingContext2D
      maskContext.clearRect(
        0,
        0,
        maskCanvasRef.current.width,
        maskCanvasRef.current.height
      )
      return
    }

    // Prepare the inputs needed for decoding
    const reshaped = imageProcessed.current.reshaped_input_sizes[0]
    const points = markPoints
      .map((x) => [x.position[0] * reshaped[1], x.position[1] * reshaped[0]])
      .flat(Infinity)
    const labels = markPoints.map((x) => BigInt(x.label)).flat(Infinity)

    const num_points = markPoints.length
    const input_points = new Tensor('float32', points, [1, 1, num_points, 2])
    const input_labels = new Tensor('int64', labels, [1, 1, num_points])

    // Generate the masks
    const { pred_masks, iou_scores } = await modelRef.current({
      ...imageEmbeddings.current,
      input_points,
      input_labels
    })

    // Post-process the masks
    const masks = await (processorRef.current as any).post_process_masks(
      pred_masks,
      imageProcessed.current.original_sizes,
      imageProcessed.current.reshaped_input_sizes
    )

    updateMaskOverlay(RawImage.fromTensor(masks[0][0]), iou_scores.data)
  }

  const clamp = (x: number, min = 0, max = 1) => {
    return Math.max(Math.min(x, max), min)
  }

  // Handle clicks on the image
  const clickImage = (e: MouseEvent) => {
    if (segmentationStatus !== SEGMENTATION_STATUS.PROCESSING_SUCCESS) {
      return
    }

    const { clientX, clientY, currentTarget } = e
    const { left, top } = currentTarget.getBoundingClientRect()

    const x = clamp(
      (clientX - left + currentTarget.scrollLeft) / currentTarget.scrollWidth
    )
    const y = clamp(
      (clientY - top + currentTarget.scrollTop) / currentTarget.scrollHeight
    )

    const existingPointIndex = markPoints.findIndex(
      (point) =>
        Math.abs(point.position[0] - x) < 0.01 &&
        Math.abs(point.position[1] - y) < 0.01 &&
        point.label === (pointStatus ? 1 : 0)
    )

    const newPoints = [...markPoints]
    if (existingPointIndex !== -1) {
      // If a mark already exists at the clicked spot, remove it
      newPoints.splice(existingPointIndex, 1)
    } else {
      newPoints.push({
        position: [x, y],
        label: pointStatus ? 1 : 0
      })
    }

    setMarkPoints(newPoints)
    decode(newPoints)
  }

  const clearPoints = () => {
    setMarkPoints([])
    decode([])
  }

  return (
    <div className="card shadow-xl overflow-auto">
      <div className="flex items-center gap-x-3">
        <button className="btn btn-primary btn-sm" onClick={clearPoints}>
          Clear Marks
        </button>

        <button
          className="btn btn-primary btn-sm"
          onClick={() => setPointStatus((s) => !s)}
        >
          {pointStatus ? 'Positive Mark' : 'Negative Mark'}
        </button>
      </div>
      <div className="text-xs text-base-content mt-2">{segmentationTip}</div>
      <div
        id="test-image-container"
        className="relative mt-4 border border-base-content border-dashed rounded-lg h-[60vh] max-h-[500px] w-fit max-w-[60vw] overflow-x-auto overflow-y-hidden"
        onClick={clickImage}
      >
        {segmentationStatus !== SEGMENTATION_STATUS.PROCESSING_SUCCESS && (
          <div className="absolute z-[3] top-0 left-0 w-full h-full bg-slate-400 bg-opacity-70 flex justify-center items-center">
            <LoadingIcon className="animate-spin" />
          </div>
        )}
        <div className="h-full w-max relative overflow-hidden">
          <img className="h-full max-w-none" src={url} />

          <canvas
            ref={maskCanvasRef}
            className="absolute top-0 left-0 h-full w-full z-[1] opacity-60"
          ></canvas>

          {markPoints.map((point, index) => {
            switch (point.label) {
              case 1:
                return (
                  <PositiveIcon
                    key={index}
                    className="w-[24px] h-[24px] absolute z-[2] -ml-[13px] -mt-[14px] fill-[#FFD401]"
                    style={{
                      top: `${point.position[1] * 100}%`,
                      left: `${point.position[0] * 100}%`
                    }}
                  />
                )
              case 0:
                return (
                  <NegativeIcon
                    key={index}
                    className="w-[24px] h-[24px] absolute z-[2] -ml-[13px] -mt-[14px] fill-[#F44237]"
                    style={{
                      top: `${point.position[1] * 100}%`,
                      left: `${point.position[0] * 100}%`
                    }}
                  />
                )
              default:
                return null
            }
          })}
        </div>
      </div>
    </div>
  )
}

export default ImageSegmentation
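Step 5 (saving the segmented image) is not included in the component above. Here is a minimal sketch of the idea, assuming the original image element and the mask canvas drawn by updateMaskOverlay share the same dimensions (exportSegmentedImage is a hypothetical helper, not taken from my source):

// Hedged sketch of step 5: keep only the pixels the mask covers and
// export them as a transparent PNG.
function exportSegmentedImage(
  img: HTMLImageElement,
  maskCanvas: HTMLCanvasElement
): string {
  const canvas = document.createElement('canvas')
  canvas.width = maskCanvas.width
  canvas.height = maskCanvas.height
  const ctx = canvas.getContext('2d') as CanvasRenderingContext2D

  // Draw the original image and read back its pixels
  ctx.drawImage(img, 0, 0, canvas.width, canvas.height)
  const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height)

  // Read the overlay; unselected pixels were never drawn, so their alpha is 0
  const maskCtx = maskCanvas.getContext('2d') as CanvasRenderingContext2D
  const maskData = maskCtx.getImageData(0, 0, canvas.width, canvas.height)

  // Make every pixel outside the mask fully transparent
  for (let i = 3; i < imageData.data.length; i += 4) {
    if (maskData.data[i] === 0) {
      imageData.data[i] = 0
    }
  }

  ctx.putImageData(imageData, 0, 0)
  return canvas.toDataURL('image/png')
}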

Summary

Thanks for reading. That's all for this article; I hope you found it helpful. Likes and Stars ⭐️ are welcome, and if you have any questions, feel free to raise them in the comments.
