Merged
35 commits
- `217828c` feat(transformers): upgrade modeling_utils/RoPE/optimization/cache_ut… (wcrzlh, May 14, 2025)
- `9cc0d13` feat(transformers): upgrade modeling_utils/RoPE/optimization/cache_ut… (wcrzlh, May 15, 2025)
- `c72685e` feat(transformers): upgrade modeling_utils/RoPE/optimization/cache_ut… (wcrzlh, May 15, 2025)
- `83026b7` feat(transformers): upgrade modeling_utils/RoPE/optimization/cache_ut… (wcrzlh, May 15, 2025)
- `c56d60a` precommit (wcrzlh, May 15, 2025)
- `5a33105` fix cache_util (wcrzlh, May 15, 2025)
- `2e1804e` Merge pull request #2 from wcrzlh/align_v450 (Mark-ZhouWX, May 15, 2025)
- `c46ad00` update generation_util, logit_processor, stopping_criteria, candidate… (May 15, 2025)
- `59c6b3d` fix cache_utils (wcrzlh, May 16, 2025)
- `896bda6` Merge pull request #3 from wcrzlh/align_v450 (Mark-ZhouWX, May 16, 2025)
- `5330e24` Add beam_search&For ForCausalLMLoss loss (liuchuting, May 16, 2025)
- `7678c0b` Merge pull request #4 from liuchuting/align_v450 (Mark-ZhouWX, May 16, 2025)
- `a2b8dcd` update modeling_outputs (May 16, 2025)
- `3510304` update beam_search (May 16, 2025)
- `de436a8` pre-commit check (May 16, 2025)
- `94c0bec` add support for dynamic input (May 16, 2025)
- `ae1efcb` add beam search py (May 19, 2025)
- `db3f20e` fix tie_weight dtype mismatch in pynative mode for albert (May 19, 2025)
- `6aa2758` fix beam_search (liuchuting, May 20, 2025)
- `2e10097` fix bug of dynamic input (May 20, 2025)
- `288b26f` Merge pull request #6 from liuchuting/align_v450 (Mark-ZhouWX, May 20, 2025)
- `a7701ae` add _supports_dynamic_input to PretrainedModel (May 20, 2025)
- `9c297a0` Codellama: inference script and llama UT (#1005) (wtomin, May 20, 2025)
- `bf2c113` fix pynative synchronize bug for albert[temporal] (May 20, 2025)
- `43793a6` support multimodal for init_static_cache[hack implementation] (May 21, 2025)
- `c28fca7` fix cache_utils bugs (wcrzlh, May 21, 2025)
- `886a4c4` Merge pull request #7 from wcrzlh/align_v450 (Mark-ZhouWX, May 21, 2025)
- `4273221` fix(diffusers): fix bugs about checkpoints loading in `utils/hub_util… (townwish4git, May 21, 2025)
- `ff04ea5` fix cumsum does not support int64 (May 22, 2025)
- `c82e453` Merge pull request #1 from Mark-ZhouWX/align_v450 (iugoood, May 22, 2025)
- `9ab5557` Revert "Align v450" (iugoood, May 22, 2025)
- `1bba509` Merge pull request #2 from iugoood/revert-1-align_v450 (iugoood, May 22, 2025)
- `5f7b2bc` Align v450 (#1004) (Mark-ZhouWX, May 22, 2025)
- `534f5eb` feat(diffusers/sd_safe): add stable_diffusion_safe (#990) (Cui-yshoho, May 22, 2025)
- `e0f9218` Merge branch 'mindspore-lab:master' into master (iugoood, May 23, 2025)
@@ -0,0 +1,51 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Safe Stable Diffusion

Safe Stable Diffusion was proposed in [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://huggingface.co/papers/2211.05105). Because Stable Diffusion models are trained on unfiltered, web-crawled datasets, they suffer from inappropriate degeneration: for instance, Stable Diffusion may unexpectedly generate nudity, violence, images depicting self-harm, and otherwise offensive content. Safe Stable Diffusion is an extension of Stable Diffusion that drastically reduces this type of content.

The abstract from the paper is:

*Text-conditioned image generation models have recently achieved astonishing results in image quality and text alignment and are consequently employed in a fast-growing number of applications. Since they are highly data-driven, relying on billion-sized datasets randomly scraped from the internet, they also suffer, as we demonstrate, from degenerated and biased human behavior. In turn, they may even reinforce such biases. To help combat these undesired side effects, we present safe latent diffusion (SLD). Specifically, to measure the inappropriate degeneration due to unfiltered and imbalanced training sets, we establish a novel image generation test bed-inappropriate image prompts (I2P)-containing dedicated, real-world image-to-text prompts covering concepts such as nudity and violence. As our exhaustive empirical evaluation demonstrates, the introduced SLD removes and suppresses inappropriate image parts during the diffusion process, with no additional training required and no adverse effect on overall image quality or text alignment.*

!!! tip

Use the `safety_concept` property of [`StableDiffusionPipelineSafe`] to check and edit the current safety concept:

```python
>>> from mindone.diffusers import StableDiffusionPipelineSafe

>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe", revision="refs/pr/6")
>>> pipeline.safety_concept
'an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity, bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child abuse, brutality, cruelty'
```
For each image generation, the active concept is also contained in [`StableDiffusionSafePipelineOutput`].

There are four configurations (`SafetyConfig.WEAK`, `SafetyConfig.MEDIUM`, `SafetyConfig.STRONG`, and `SafetyConfig.MAX`) that can be applied:

```python
>>> from mindone.diffusers import StableDiffusionPipelineSafe
>>> from mindone.diffusers.pipelines.stable_diffusion_safe import SafetyConfig

>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe", revision="refs/pr/6")
>>> prompt = "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c. leyendecker"
>>> out = pipeline(prompt=prompt, **SafetyConfig.MAX)
```
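For reference, each preset is simply a bundle of SLD keyword arguments; the values below mirror the `SafetyConfig` class added in this PR, so `**SafetyConfig.MAX` expands to:

```python
# Values mirrored from SafetyConfig in
# mindone/diffusers/pipelines/stable_diffusion_safe/__init__.py
MAX = {
    "sld_warmup_steps": 0,       # apply safety guidance from the first diffusion step
    "sld_guidance_scale": 5000,  # strength of the safety guidance term
    "sld_threshold": 1.0,
    "sld_momentum_scale": 0.5,
    "sld_mom_beta": 0.7,
}

# pipeline(prompt=prompt, **SafetyConfig.MAX) is therefore equivalent to
# pipeline(prompt=prompt, **MAX)
```

The other presets lower `sld_guidance_scale` and raise `sld_warmup_steps`, weakening the intervention.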

!!! tip

Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!

::: mindone.diffusers.StableDiffusionPipelineSafe

::: mindone.diffusers.pipelines.stable_diffusion_safe.StableDiffusionSafePipelineOutput
65 changes: 65 additions & 0 deletions examples/transformers/codellama/README.md
@@ -0,0 +1,65 @@

# CodeLlama

## Overview

The Code Llama models were proposed in [Code Llama: Open Foundation Models for Code](https://ai.meta.com/research/publications/code-llama-open-foundation-models-for-code/).

All Code Llama model checkpoints can be found [here](https://huggingface.co/models?search=code_llama), and the officially released checkpoints are hosted in the [meta-llama organization](https://huggingface.co/meta-llama).

This model was contributed by [ArthurZucker](https://huggingface.co/ArthurZ). The original code can be found [here](https://github.com/facebookresearch/llama).

## Checkpoints

CodeLlama checkpoints are restricted: you need to authenticate with Hugging Face before downloading them. We recommend using the following command to authenticate:

```bash
huggingface-cli login
```
Log in with a Hugging Face access token that has the correct permissions.

Afterwards, you can download the checkpoints using the following command:
```bash
huggingface-cli download --resume-download meta-llama/CodeLlama-7b-hf
```

## Examples

Here's an example usage:

```python
>>> from transformers import CodeLlamaTokenizer
>>> from mindone.transformers.models.llama import LlamaForCausalLM
>>> import mindspore as ms

>>> tokenizer = CodeLlamaTokenizer.from_pretrained("meta-llama/CodeLlama-7b-hf")
>>> model = LlamaForCausalLM.from_pretrained("meta-llama/CodeLlama-7b-hf", use_flash_attention_2=True, mindspore_dtype=ms.float16)  # model weight will be automatically downloaded from huggingface
>>> PROMPT = '''def remove_non_ascii(s: str) -> str:
    """ <FILL_ME>
    return result
'''
>>> input_ids = ms.Tensor(tokenizer(PROMPT, return_tensors="np").input_ids, ms.int32)
>>> generated_ids = model.generate(input_ids, max_new_tokens=128, do_sample=False).asnumpy()

>>> filling = tokenizer.batch_decode(generated_ids[:, input_ids.shape[1]:], skip_special_tokens=True)[0]
>>> print(PROMPT.replace("<FILL_ME>", filling))
def remove_non_ascii(s: str) -> str:
    """ Remove non-ASCII characters from a string.

    Args:
        s: The string to remove non-ASCII characters from.

    Returns:
        The string with non-ASCII characters removed.
    """
    result = ""
    for c in s:
        if ord(c) < 128:
            result += c
    return result
```
Internally, the tokenizer automatically splits on `<FILL_ME>` to create a formatted input string that follows the original training pattern. This is more robust than preparing the pattern yourself because it avoids very hard-to-debug pitfalls such as token gluing.
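As an illustration of that split (a hypothetical helper, not the actual tokenizer internals):

```python
# Hypothetical sketch of the <FILL_ME> split; the real CodeLlamaTokenizer
# additionally wraps the two halves in special infilling prefix/suffix tokens.
def split_infill_prompt(prompt: str, fill_token: str = "<FILL_ME>"):
    # everything before the fill site becomes the prefix, everything after
    # becomes the suffix the model conditions on while infilling
    prefix, suffix = prompt.split(fill_token, 1)
    return prefix, suffix

prefix, suffix = split_infill_prompt('def f(s):\n    """ <FILL_ME>\n    return result\n')
```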

The LLaMA tokenizer is a BPE model based on sentencepiece. One quirk of sentencepiece is that when decoding a sequence, if the first token is the start of a word (e.g., "Banana"), the tokenizer does not prepend the prefix space to the string.
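A toy illustration of that quirk, using sentencepiece's "▁" word-boundary metasymbol (a simplified sketch, not the actual decode implementation):

```python
# sentencepiece marks word starts with the "▁" metasymbol; decoding replaces
# it with a space, and the space before the sequence's first word is dropped
tokens = ["▁Banana", "▁split"]
decoded = "".join(tokens).replace("▁", " ").lstrip()
# no leading space is prepended before "Banana"
```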

Code Llama has the same architecture as the Llama2 models. For API reference, see the Llama2 documentation page.
97 changes: 97 additions & 0 deletions examples/transformers/codellama/generate.py
@@ -0,0 +1,97 @@
import argparse
import ast
import os
import time

from transformers import CodeLlamaTokenizer

import mindspore as ms

from mindone.transformers.models.llama import LlamaForCausalLM


def run_codellama_generate(args):
    print("=====> test_codellama_generate:")
    print("=====> Building model...")

    s_time = time.time()

    tokenizer = CodeLlamaTokenizer.from_pretrained(args.model_path)
    model = LlamaForCausalLM.from_pretrained(
        args.model_path, use_flash_attention_2=args.use_fa, mindspore_dtype=ms.float16
    )

    print("=====> Building model done.")

    PROMPT = '''def remove_non_ascii(s: str) -> str:
    """ <FILL_ME>
    return result
'''

    prompt = [
        PROMPT,
    ]
    input_ids = ms.Tensor(tokenizer(prompt, return_tensors="np").input_ids, ms.int32)

    input_kwargs = {}
    if args.use_embed_input:
        input_kwargs["inputs_embeds"] = model.get_input_embeddings()(input_ids)
    else:
        input_kwargs["input_ids"] = input_ids

    generated_ids = model.generate(**input_kwargs, use_cache=args.use_cache, max_new_tokens=128, do_sample=False)
    generated_ids = generated_ids.asnumpy()

    filling = tokenizer.batch_decode(generated_ids[:, input_ids.shape[1] :], skip_special_tokens=True)[0].strip()

    print(f"=====> input prompt: {prompt}, time cost: {time.time() - s_time:.2f}s")
    print("=" * 46 + " Result " + "=" * 46)
    print(PROMPT.replace("<FILL_ME>", filling))
    print("=" * 100)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="test")
    parser.add_argument("--ms_mode", type=int, default=0, help="0 is Graph, 1 is Pynative")
    parser.add_argument("--jit_level", type=str, default="O0")
    parser.add_argument("--model_path", type=str, default="meta-llama/CodeLlama-7b-hf")
    parser.add_argument("--use_fa", type=ast.literal_eval, default=True)
    parser.add_argument("--use_cache", type=ast.literal_eval, default=True)
    parser.add_argument("--use_embed_input", type=ast.literal_eval, default=False)
    args, _ = parser.parse_known_args()

    if args.ms_mode == ms.GRAPH_MODE:
        if os.environ.get("MS_DEV_RUNTIME_CONF") is None:
            os.environ["MS_DEV_RUNTIME_CONF"] = "synchronize:True"
            print("WARNING: os environment MS_DEV_RUNTIME_CONF synchronize has not been set, force setting it now.")
        elif "synchronize:True" not in os.environ["MS_DEV_RUNTIME_CONF"]:
            _old = os.environ["MS_DEV_RUNTIME_CONF"]
            # str.replace returns a new string, so the result must be reassigned
            _old = _old.replace("synchronize:False,", "")
            _old = _old.replace(",synchronize:False", "")
            _old = _old.replace("synchronize:False", "")
            _new = "synchronize:True," + _old if len(_old) > 0 else "synchronize:True"
            os.environ["MS_DEV_RUNTIME_CONF"] = _new
            print("WARNING: os environment MS_DEV_RUNTIME_CONF synchronize has not been enabled, force enabling it now.")

        ms.set_context(
            mode=ms.GRAPH_MODE,
            device_target="Ascend",
            jit_config={"jit_level": args.jit_level},
            max_device_memory="59GB",
            deterministic="ON",
        )

    elif args.ms_mode == ms.PYNATIVE_MODE:
        ms.set_context(
            mode=ms.PYNATIVE_MODE,
            device_target="Ascend",
            pynative_synchronize=True,
            max_device_memory="59GB",
            deterministic="ON",
        )

    else:
        raise ValueError(f"Unsupported ms_mode: {args.ms_mode}. Expected 0 (Graph) or 1 (Pynative).")

    run_codellama_generate(args)
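The `MS_DEV_RUNTIME_CONF` handling in the graph-mode branch can be summarized as a small pure function (a sketch for clarity, not part of the script):

```python
def ensure_synchronize(conf):
    """Return MS_DEV_RUNTIME_CONF with synchronize:True guaranteed present.

    Sketch of the normalization in generate.py: drop any synchronize:False
    entry, then prepend synchronize:True if it is not already set.
    """
    if conf is None:
        return "synchronize:True"
    if "synchronize:True" in conf:
        return conf
    # note: str.replace returns a new string, so reassignment is required
    for pattern in ("synchronize:False,", ",synchronize:False", "synchronize:False"):
        conf = conf.replace(pattern, "")
    return "synchronize:True," + conf if conf else "synchronize:True"
```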
2 changes: 2 additions & 0 deletions mindone/diffusers/__init__.py
@@ -219,6 +219,7 @@
"StableDiffusionPAGPipeline",
"StableDiffusionPanoramaPipeline",
"StableDiffusionPipeline",
"StableDiffusionPipelineSafe",
"StableDiffusionSAGPipeline",
"StableDiffusionUpscalePipeline",
"StableDiffusionXLAdapterPipeline",
@@ -505,6 +506,7 @@
StableDiffusionPAGPipeline,
StableDiffusionPanoramaPipeline,
StableDiffusionPipeline,
StableDiffusionPipelineSafe,
StableDiffusionSAGPipeline,
StableDiffusionUpscalePipeline,
StableDiffusionXLAdapterPipeline,
2 changes: 2 additions & 0 deletions mindone/diffusers/pipelines/__init__.py
@@ -176,6 +176,7 @@
"StableDiffusion3Img2ImgPipeline",
"StableDiffusion3InpaintPipeline",
],
"stable_diffusion_safe": ["StableDiffusionPipelineSafe"],
"stable_diffusion_sag": ["StableDiffusionSAGPipeline"],
"stable_diffusion_gligen": [
"StableDiffusionGLIGENPipeline",
@@ -379,6 +380,7 @@
from .stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline, StableDiffusionXLKDiffusionPipeline
from .stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline
from .stable_diffusion_panorama import StableDiffusionPanoramaPipeline
from .stable_diffusion_safe import StableDiffusionPipelineSafe
from .stable_diffusion_sag import StableDiffusionSAGPipeline
from .stable_diffusion_xl import (
StableDiffusionXLImg2ImgPipeline,
14 changes: 9 additions & 5 deletions mindone/diffusers/pipelines/stable_diffusion/safety_checker.py
@@ -52,23 +52,27 @@ def construct(self, clip_input: Tensor, images: Tensor):
         pooled_output = self.vision_model(clip_input)[1]  # pooled_output
         image_embeds = self.visual_projection(pooled_output)

-        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
-        cos_dist = cosine_distance(image_embeds, self.concept_embeds)
+        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).float()
+        cos_dist = cosine_distance(image_embeds, self.concept_embeds).float()

         # increase this value to create a stronger `nsfw` filter
         # at the cost of increasing the possibility of filtering benign images
         adjustment = 0.0

         special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
-        # special_scores = special_scores.round(decimals=3)
+        special_scores = ops.round(special_scores, decimals=3)
         special_care = ops.any(special_scores > 0, axis=1)
         special_adjustment = special_care * 0.01
         special_adjustment = special_adjustment.unsqueeze(1).tile((1, cos_dist.shape[1]))

         concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
-        # concept_scores = concept_scores.round(decimals=3)
+        concept_scores = ops.round(concept_scores, decimals=3)
         has_nsfw_concepts = ops.any(concept_scores > 0, axis=1)

-        images[has_nsfw_concepts] = 0.0  # black image
+        if ops.is_tensor(images):
+            images[has_nsfw_concepts] = 0.0  # black image
+        else:
+            # TODO: if has_nsfw_concepts is tensor and images is array, the images will be wrong.
+            images[has_nsfw_concepts.numpy()] = 0.0  # black image

         return images, has_nsfw_concepts
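Per image, the scoring above reduces to a thresholded distance check; a plain-Python sketch of the idea (simplified, omitting tensor types and the special-care adjustment broadcast):

```python
def flags_nsfw(cos_dist, thresholds, adjustment=0.0):
    # score_i = round(dist_i - threshold_i + adjustment, 3); an image is
    # flagged when any per-concept score is positive, mirroring
    # ops.any(concept_scores > 0, axis=1) in the safety checker
    scores = [round(d - t + adjustment, 3) for d, t in zip(cos_dist, thresholds)]
    return any(s > 0 for s in scores)
```

Raising `adjustment` makes the filter stricter at the cost of flagging more benign images, which is exactly what the `adjustment = 0.0` comment in the checker describes.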
74 changes: 74 additions & 0 deletions mindone/diffusers/pipelines/stable_diffusion_safe/__init__.py
@@ -0,0 +1,74 @@
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Union

import numpy as np
import PIL
from PIL import Image

from ...utils import BaseOutput, _LazyModule


@dataclass
class SafetyConfig(object):
    WEAK = {
        "sld_warmup_steps": 15,
        "sld_guidance_scale": 20,
        "sld_threshold": 0.0,
        "sld_momentum_scale": 0.0,
        "sld_mom_beta": 0.0,
    }
    MEDIUM = {
        "sld_warmup_steps": 10,
        "sld_guidance_scale": 1000,
        "sld_threshold": 0.01,
        "sld_momentum_scale": 0.3,
        "sld_mom_beta": 0.4,
    }
    STRONG = {
        "sld_warmup_steps": 7,
        "sld_guidance_scale": 2000,
        "sld_threshold": 0.025,
        "sld_momentum_scale": 0.5,
        "sld_mom_beta": 0.7,
    }
    MAX = {
        "sld_warmup_steps": 0,
        "sld_guidance_scale": 5000,
        "sld_threshold": 1.0,
        "sld_momentum_scale": 0.5,
        "sld_mom_beta": 0.7,
    }


_additional_imports = {}
_import_structure = {}

_additional_imports.update({"SafetyConfig": SafetyConfig})

_import_structure.update(
    {
        "pipeline_output": ["StableDiffusionSafePipelineOutput"],
        "pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"],
        "safety_checker": ["SafeStableDiffusionSafetyChecker"],
    }
)


if TYPE_CHECKING:
    from .pipeline_output import StableDiffusionSafePipelineOutput
    from .pipeline_stable_diffusion_safe import StableDiffusionPipelineSafe
    from .safety_checker import SafeStableDiffusionSafetyChecker

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _additional_imports.items():
        setattr(sys.modules[__name__], name, value)
@@ -0,0 +1,32 @@
from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import PIL.Image

from ...utils import BaseOutput


@dataclass
class StableDiffusionSafePipelineOutput(BaseOutput):
    """
    Output class for Safe Stable Diffusion pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
        nsfw_content_detected (`List[bool]`)
            List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, or `None` if safety checking could not be performed.
        unsafe_images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images that were flagged by the safety checker and may contain "not-safe-for-work"
            (nsfw) content, or `None` if no safety check was performed or no images were flagged.
        applied_safety_concept (`str`)
            The safety concept that was applied for safety guidance, or `None` if safety guidance was disabled
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
    nsfw_content_detected: Optional[List[bool]]
    unsafe_images: Optional[Union[List[PIL.Image.Image], np.ndarray]]
    applied_safety_concept: Optional[str]