KerasHub统一、全面的预训练模型库

描述

深度学习领域正在迅速发展,在处理各种类型的任务中,预训练模型变得越来越重要。Keras 以其用户友好型 API 和对易用性的重视而闻名,始终处于这一动向的前沿。Keras 拥有专用的内容库,如用于文本模型的 KerasNLP 和用于计算机视觉模型的 KerasCV。

然而,随着模型使各模态之间的界限越来越模糊 (想象一下强大的聊天 LLM 具有图像输入功能或是在视觉任务中利用文本编码器),维持这些独立的领域变得不那么实际。NLP 和 CV 之间的区别可能会阻碍真正多模态模型的发展和部署,从而导致冗余的工作和碎片化的用户体验。

 

为了解决这个问题,我们很高兴地宣布 Keras 生态系统迎来重大变革: 隆重推出 KerasHub,一个统一、全面的预训练模型库,简化了对前沿 NLP 和 CV 架构的访问。KerasHub 是一个中央存储库,您可以在稳定且熟悉的 Keras 框架内无缝探索和使用最先进的模型,例如用于文本分析的 BERT 以及用于图像分类的 EfficientNet。

KerasHub https://keras.io/keras_hub/

统一的开发者体验

这种统一不仅简化了对模型的探索和使用,还有助于打造更具凝聚力的生态系统。通过 KerasHub,您可以利用高级功能,例如轻松的发布和共享模型、用于优化资源效率的 LoRA 微调、用于优化性能的量化,以及用于处理大规模数据集的强大多主机训练,所有这些功能都适用于各种模态。这标志着在普及强大的 AI 工具以及加速开发创新型多模态应用方面迈出了重要一步。

KerasHub 入门步骤

首先在您的系统上安装 KerasHub,您可以在其中探索大量现成的模型和主流架构的不同实现方式。然后,您就可以轻松地将这些预训练的模型加载并整合到自己的项目中,并根据您的具体需求对其进行微调,以获得最佳性能。

现成的模型 https://keras.io/api/keras_hub/models/

安装 KerasHub

要安装带有 Keras 3 的 KerasHub 最新版本,只需运行以下代码:

 

$ pip install --upgrade keras-hub
现在,您可以开始探索可用的模型。使用 Keras 3 开始工作的标准环境设置在开始使用 KerasHub 时并不需要任何改变:
import os


# Define the Keras 3 backend you want to use - "jax", "tensorflow" or "torch"
os.environ["KERAS_BACKEND"] = "jax"


# Import Keras 3 and KerasHub modules
import keras
import keras_hub

 

 

通过 KerasHub 使用

计算机视觉和自然语言模型

现在,您可以通过 KerasHub 访问和使用 Keras 3 生态系统中的模型。以下是一些示例:

Gemma

Gemma 是由 Google 开发的一系列前沿且易于使用的开放模型。依托于与 Gemini 模型相同的研究和技术,Gemma 的基础模型在各种文本生成任务中表现出色,包括回答问题、总结信息以及进行逻辑推理。此外,您还可以针对特定需求自定义模型。  

Gemma https://ai.google.dev/gemma/docs/base

在此示例中,您可以使用 Keras 和 KerasHub 加载并开始使用 Gemma 2 2B 参数生成内容。有关 Gemma 变体的更多详细信息,请查看 Kaggle 上的 Gemma 模型卡。

 

# Load Gemma 2 2B preset from Kaggle models 
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")


# Start generating contents with Gemma 2 2B
gemma_lm.generate("Keras is a", max_length=32)

 

 

Gemma 模型卡 https://www.kaggle.com/models/google/gemma/

PaliGemma

PaliGemma 是一款紧凑型的开放模型,可以理解图像和文本。PaliGemma 从 PaLI-3 中汲取灵感,以 SigLIP 视觉模型和 Gemma 语言模型等开源组件为基础,可以针对有关图像的问题提供详细且富有洞察力的答案。因此,该模型可以更深入地了解视觉内容,从而实现诸多功能,例如为图像和短视频生成描述、识别对象甚至理解图像中的文本。

 

import os


# Define the Keras 3 backend you want to use - "jax", "tensorflow" or "torch"
os.environ["KERAS_BACKEND"] = "jax"


# Import Keras 3 and KerasHub modules
import keras
import keras_hub
from keras.utils import get_file, load_img, img_to_array




# Import PaliGemma 3B fine tuned with 224x224 images
pali_gemma_lm = keras_hub.models.PaliGemmaCausalLM.from_preset(
    "pali_gemma_3b_mix_224"
)


# Download a test image and prepare it for usage with KerasHub
url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
img_path = get_file(origin=url)
img = img_to_array(load_img(image_path))


# Create the prompt with the question about the image
prompt = 'answer where is the cow standing?'


# Generate the contents with PaliGemma
output = pali_gemma_lm.generate(
    inputs={
        "images": img,
        "prompts": prompt,
    }
)

 

 

PaliGemma https://ai.google.dev/gemma/docs/paligemma

PaLI-3 https://arxiv.org/abs/2310.09199

SigLIP 视觉模型 https://arxiv.org/abs/2303.15343

Gemma 语言模型 https://arxiv.org/abs/2403.08295

有关 Keras 3 上可用的预训练模型的更多详细信息,请在 Kaggle 上查看 Keras 中的模型列表。  

Kaggle 上查看 Keras 中的模型列表 https://www.kaggle.com/organizations/keras/models

Stability.ai Stable Diffusion 3

您也可以使用计算机视觉模型。例如,您可以通过 KerasHub 使用 stability.ai Stable Diffusion 3:

 

from PIL import Image
from keras.utils import array_to_img
from keras_hub.models import StableDiffusion3TextToImage


text_to_image = StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3_medium",
    height=1024,
    width=1024,
    dtype="float16",
)


# Generate images with SD3
image = text_to_image.generate(
    "photograph of an astronaut riding a horse, detailed, 8k",
)


# Display the generated image
img = array_to_img(image)
img

 

 

Stable Diffusion 3 https://stability.ai/news/stable-diffusion-3

有关 Keras 3 上可用的预训练计算机视觉模型的更多详细信息,请查看 Keras 中的模型列表。  

Keras 中的模型列表 https://keras.io/api/keras_hub/models/

对于 KerasNLP 开发者而言,

有哪些变化?

从 KerasNLP 到 KerasHub 的过渡是一个简单的过程。只需要将 import 语句从 keras_nlp 更新为 keras_hub。

示例: 以前,您可能需要导入 keras_nlp 才能使用 BERT 模型,如下所示

 

import keras_nlp


# Load a BERT model 
classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_base_en_uncased", 
    num_classes=2,
)
现在,您只需调整 import,即可使用 KerasHub:
import keras_hub


# Load a BERT model 
classifier = keras_hub.models.BertClassifier.from_preset(
    "bert_base_en_uncased", 
    num_classes=2,
)
     

 

 

对于 KerasCV 开发者而言,

有哪些变化?

如果您当前是 KerasCV 用户,更新到 KerasHub 能够为您带来以下好处:

简化模型加载: KerasHub 为加载模型提供了统一的 API,如果您同时使用 KerasCV 和 KerasNLP,这可以简化您的代码。

框架灵活性: 如果您有兴趣探索 JAX 或 PyTorch 等不同框架,KerasHub 可以让您更轻松地将这些框架与 KerasCV 和 KerasNLP 模型结合起来使用。

集中式存储库: 借助 KerasHub 的统一模型存储库,您可以更轻松地查找和访问模型,未来还可以在其中添加新架构。

如何使我的代码适配 KerasHub?

模型

目前,我们正在将 KerasCV 模型迁移到 KerasHub。虽然大多数模型已经可用,但有些仍在迁移中。请注意,Centerpillar 模型不会被迁移。您应该能够在 KerasHub 使用任何视觉模型,方法如下:

 

import keras_hub


# Load a model using preset
Model = keras_hub.models..from_preset('preset_name`)


# or load a custom model by specifying the backbone and preprocessor
Model = keras_hub.models.(backbone=backbone, preprocessor=preprocessor)

 

 

Centerpillar https://www.kaggle.com/models/keras/centerpillar

KerasHub 为 KerasCV 开发者带来了激动人心的新功能,提供了更高的灵活性和扩展能力。其中包括:

内置预处理

每个模型都配备了一个定制的预处理器,用于处理包括调整大小、重新缩放等常规任务,从而简化您的工作流程。   在此之前,预处理输入是在向模型提供输入之前手动执行的。

 

# Preprocess inputs for example
def preprocess_inputs(image, label):
    # Resize rescale or do more preprocessing on inputs
    return preprocessed_inputs
backbone = keras_cv.models.ResNet50V2Backbone.from_preset(
    "resnet50_v2_imagenet",
)
model = keras_cv.models.ImageClassifier(
    backbone=backbone,
    num_classes=4,
)
output = model(preprocessed_input)
  目前,任务模型的预处理已集成到现成的预设中。预处理器会对输入进行预处理,对样本图像进行大小调整和重新缩放。预处理器是任务模型的内在组件。尽管如此,开发者还是可以选择使用个性化的预处理器。
classifier = keras_hub.models.ImageClassifier.from_preset('resnet_18_imagenet')
classifier.predict(inputs)

 

 

损失函数

与增强层类似,以前 KerasCV 中的损失函数现在可在 Keras 中通过 keras.losses. 使用。例如,如果您当前正在使用 FocalLoss 函数:

 

import keras
import keras_cv


keras_cv.losses.FocalLoss(
    alpha=0.25, gamma=2, from_logits=False, label_smoothing=0, **kwargs
)

 

 

FocalLoss 函数 https://keras.io/api/keras_cv/losses/focal_loss/

您只需调整损失函数定义代码,使用 keras.losses 而不是 keras_cv.losses:

 

import keras


keras.losses.FocalLoss(
    alpha=0.25, gamma=2, from_logits=False, label_smoothing=0, **kwargs
)

 

 

打开APP阅读更多精彩内容
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

全部0条评论

快来发表一下你的评论吧 !

×
20
完善资料,
赚取积分