当前位置: 首页 > article >正文

lit-llama代码解析

https://github.com/Lightning-AI/lit-llama/blob/main/README.md

下载的时候会报错误,因为网不行,一种方法就是多次尝试,另一种方法是终端连上代理下载

pycharm连接hugging face等网站_hugging face怎么连接-CSDN博客

根据指引下载权重

下载完权重运行:python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/open-llama/7B --model_size 7B

转化为.pth文件 

跟着readme/howto教程量化或进行其他操作

warning

UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at ..\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:455.)
  y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

https://github.com/comfyanonymous/ComfyUI/issues/3202

分析generate

# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import sys
import time
import warnings
from pathlib import Path
from typing import Optional

import lightning as L
import torch
print(torch.cuda.is_available())
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from lit_llama import LLaMA, Tokenizer
from lit_llama.utils import lazy_load, llama_model_lookup, quantization


@torch.no_grad()
def generate(
    model: LLaMA,
    idx: torch.Tensor,
    max_new_tokens: int,
    *,
    max_seq_length: Optional[int] = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    eos_id: Optional[int] = None,
) -> torch.Tensor:
    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    The implementation of this function is modified from A. Karpathy's nanoGPT.

    Args:
        model: The model to use.
        idx: Tensor of shape (T) with indices of the prompt sequence.
        max_new_tokens: The number of new tokens to generate.
        max_seq_length: The maximum sequence length allowed.
        temperature: Scales the predicted logits by 1 / temperature
        top_k: If specified, only sample among the tokens with the k highest probabilities
        eos_id: If specified, stop generating any more token once the <eos> token is triggered
    """
    # create an empty tensor of the expected final shape and fill in the current tokens
    T = idx.size(0)
    T_new = T + max_new_tokens
    if max_seq_length is None:
        max_seq_length = min(T_new, model.config.block_size)

    device, dtype = idx.device, idx.dtype
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(T_new, dtype=dtype, device=device)
    empty[:T] = idx
    idx = empty
    input_pos = torch.arange(0, T, device=device)

    if idx.device.type == "xla":
        import torch_xla.core.xla_model as xm

        xm.mark_step()

    # generate max_new_tokens tokens
    for _ in range(max_new_tokens):
        x = idx.index_select(0, input_pos).view(1, -1)

        # forward
        logits = model(x, max_seq_length, input_pos)
        logits = logits[0, -1] / temperature

        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits = torch.where(logits < v[[-1]], -float("Inf"), logits)

        probs = torch.nn.functional.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)

        # advance
        input_pos = input_pos[-1:] + 1

        if idx.device.type == "xla":
            xm.mark_step()

        # concatenate the new generation
        idx = idx.index_copy(0, input_pos, idx_next)

        # if <eos> token is triggered, return the output (stop generation)
        if idx_next == eos_id:
            return idx[:input_pos]  # include the EOS token

    return idx


def main(
    prompt: str = "Hello, my name is",
    *,
    num_samples: int = 1,
    max_new_tokens: int = 50,
    top_k: int = 200,
    temperature: float = 0.8,
    checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
    tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
    quantize: Optional[str] = None,
) -> None:
    """Generates text samples based on a pre-trained LLaMA model and tokenizer.

    Args:
        prompt: The prompt string to use for generating the samples.
        num_samples: The number of text samples to generate.(Its effect is overridden by `max_new_tokens`, if also set.)
        max_new_tokens: The number of generation steps to take.(number of generate tokens )
        top_k: The number of top most probable tokens to consider in the sampling process.
        temperature: A value controlling the randomness of the sampling process. Higher values result in more random
            samples.
        checkpoint_path: The checkpoint path to load.
        tokenizer_path: The tokenizer path to load.
        quantize: Whether to quantize the model and using which method:
            ``"llm.int8"``: LLM.int8() mode,
            ``"gptq.int4"``: GPTQ 4-bit mode.
    """
    assert checkpoint_path.is_file(), checkpoint_path
    assert tokenizer_path.is_file(), tokenizer_path

    precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
    fabric = L.Fabric(devices=1, precision=precision)

    print("Loading model ...", file=sys.stderr)
    t0 = time.time()
    with lazy_load(checkpoint_path) as checkpoint:
        name = llama_model_lookup(checkpoint)

        with fabric.init_module(empty_init=True), quantization(mode=quantize):
            model = LLaMA.from_name(name)

        model.load_state_dict(checkpoint)
    print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)

    model.eval()
    model = fabric.setup(model)

    tokenizer = Tokenizer(tokenizer_path)
    encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
    prompt_length = encoded.size(0)

    L.seed_everything(1234)
    for i in range(num_samples):
        t0 = time.perf_counter()
        y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
        t = time.perf_counter() - t0

        model.reset_cache()
        print(tokenizer.decode(y))
        tokens_generated = y.size(0) - prompt_length
        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
    if fabric.device.type == "cuda":
        print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)


if __name__ == "__main__":
    from jsonargparse import CLI

    torch.set_float32_matmul_precision("high")
    warnings.filterwarnings(
        # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
        "ignore", 
        message="ComplexHalf support is experimental and many operators don't support it yet"
    )
    warnings.filterwarnings(
        # Triggered in bitsandbytes/autograd/_functions.py:298
        "ignore", 
        message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",
    )
    CLI(main)

main()

"""Generates text samples based on a pre-trained LLaMA model and tokenizer.

Args:
    prompt: The prompt string to use for generating the samples.
    num_samples: The number of text samples to generate.(Its effect is overridden by `max_new_tokens`, if also set.)
    max_new_tokens: The number of generation steps to take.(number of generate tokens )
    top_k: The number of top most probable tokens to consider in the sampling process.
    temperature: A value controlling the randomness of the sampling process. Higher values result in more random samples.
    checkpoint_path: The checkpoint path to load.
    tokenizer_path: The tokenizer path to load.
    quantize: Whether to quantize the model and using which method:
        ``"llm.int8"``: LLM.int8() mode,
        ``"gptq.int4"``: GPTQ 4-bit mode.
"""


https://zhuanlan.zhihu.com/p/657886517

Fabric()

r"""Fabric accelerates your PyTorch training or inference code with minimal changes required.
    Fabric 加速你的 PyTorch 训练或推理代码,所需的更改最小。
    
    - Automatic placement of models and data onto the device.
    - 自动将模型和数据放置到设备上。
    
    - Automatic support for mixed and double precision (smaller memory footprint).
    - 自动支持混合精度和双精度(较小的内存占用)。
    
    - Seamless switching between hardware (CPU, GPU, TPU) and distributed training strategies
      (data-parallel training, sharded training, etc.).
    - 在硬件(CPU、GPU、TPU)和分布式训练策略(数据并行训练、分片训练等)之间无缝切换。
    
    - Automated spawning of processes, no launch utilities required.
    - 自动生成进程,无需启动工具。
    
    - Multi-node support.
    - 支持多节点训练。

    Args:
        accelerator: The hardware to run on. Possible choices are:
            ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
        accelerator: 运行的硬件。可能的选择有:
            ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``。
        
        strategy: Strategy for how to run across multiple devices. Possible choices are:
            ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``.
        strategy: 跨多个设备运行的策略。可能的选择有:
            ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``。
        
        devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
            The value applies per node.
        devices: 训练时使用的设备数量(``int``),或要训练的 GPU(``list`` 或 ``str``),或 ``"auto"``。
            该值适用于每个节点。
        
        num_nodes: Number of GPU nodes for distributed training.
        num_nodes: 用于分布式训练的 GPU 节点数量。
        
        precision: Double precision (``"64"``), full precision (``"32"``), half precision AMP (``"16-mixed"``),
            or bfloat16 precision AMP (``"bf16-mixed"``).
        precision: 双精度(``"64"``),全精度(``"32"``),半精度 AMP(``"16-mixed"``),
            或 bfloat16 精度 AMP(``"bf16-mixed"``)。
        
        plugins: One or several custom plugins
        plugins: 一个或多个自定义插件
        
        callbacks: A single callback or a list of callbacks. A callback can contain any arbitrary methods that
            can be invoked through :meth:`~lightning.fabric.fabric.Fabric.call` by the user.
        callbacks: 单个回调或回调列表。回调可以包含任何用户可以通过 :meth:`~lightning.fabric.fabric.Fabric.call` 调用的任意方法。
        
        loggers: A single logger or a list of loggers. See :meth:`~lightning.fabric.fabric.Fabric.log` for more
            information.
        loggers: 单个日志记录器或日志记录器列表。有关更多信息,请参见 :meth:`~lightning.fabric.fabric.Fabric.log`。
"""

lazy_load()

定义了一个名为 lazy_load 的类,它用于延迟加载和管理一个 PyTorch 文件:

lazy_load 类
__init__ 方法
python
def __init__(self, fn):
    self.zf = torch._C.PyTorchFileReader(str(fn))
    with BytesIO(self.zf.get_record("data.pkl")) as pkl:
        mup = LazyLoadingUnpickler(pkl, self)
        self.sd = mup.load()
self.zf = torch._C.PyTorchFileReader(str(fn)):

创建一个 PyTorchFileReader 实例,用于读取指定文件 (fn) 的内容。这个文件是 PyTorch 保存的文件,通常是 .pt 或 .pth 文件。
str(fn) 确保文件路径被正确转换为字符串。
with BytesIO(self.zf.get_record("data.pkl")) as pkl::

从 PyTorchFileReader 中提取名为 "data.pkl" 的记录,并用 BytesIO 创建一个内存中的字节流对象 pkl。
BytesIO 用于在内存中读写二进制数据。
mup = LazyLoadingUnpickler(pkl, self):

创建一个 LazyLoadingUnpickler 实例 mup,它负责处理 pkl 中的数据。这里假设 LazyLoadingUnpickler 是自定义的类,用于延迟加载和解码 Pickle 数据。
self.sd = mup.load():

调用 mup.load() 方法来加载数据,并将结果存储在 self.sd 属性中。这个过程可能会涉及到数据的反序列化。
__enter__ 方法
python
def __enter__(self):
    return self.sd
这个方法允许 lazy_load 实例在上下文管理器(with 语句)中使用。__enter__ 返回 self.sd,使得 with 语句块内部可以直接访问加载的数据。
__exit__ 方法
python
def __exit__(self, exc_type, exc_val, exc_tb):
    del self.zf  # I don't think there is a way to force closing...
    self.zf = None
这个方法用于处理退出上下文管理器时的清理工作。
del self.zf: 尝试删除 self.zf 对象。由于 self.zf 是一个 PyTorchFileReader 实例,删除对象的作用是释放相关资源。
self.zf = None: 另一种释放资源的方式,将 self.zf 设置为 None,以确保它不再被引用。
总结
这个类的设计用于懒加载 PyTorch 文件中的数据。它实现了上下文管理协议,使得数据可以在 with 语句块中方便地访问,并且在退出时尝试释放相关资源。

LazyLoadingUnpickler()

定义了一个 LazyLoadingUnpickler 类,继承自 pickle.Unpickler,用于处理 PyTorch 对象的延迟加载。以下是对每个部分的详细解释:

__init__ 方法
python
def __init__(self, file, zipfile_context):
    super().__init__(file)
    self.zipfile_context = zipfile_context
file: 传入的文件对象(通常是一个字节流),用于反序列化。
zipfile_context: 额外的上下文信息,用于延迟加载的实现。这通常是一个包含 PyTorch 文件读取信息的对象。
super().__init__(file): 调用父类 pickle.Unpickler 的初始化方法,传入文件对象。
self.zipfile_context: 保存额外的上下文信息,用于稍后延迟加载。
find_class 方法
python
def find_class(self, module, name):
    res = super().find_class(module, name)
    if module == "torch._utils" and name == "_rebuild_tensor_v2":
        return functools.partial(
            NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self
        )
    elif module == "torch._tensor" and name == "_rebuild_from_type_v2":
        return functools.partial(
            NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self
        )
    elif module == "torch._utils" and name == "_rebuild_parameter":
        return functools.partial(
            NotYetLoadedTensor.rebuild_parameter, archiveinfo=self
        )
    return res
super().find_class(module, name): 调用父类的 find_class 方法,查找并返回指定模块和类名的类。
模块和类名检查:
当模块是 "torch._utils" 且类名是 "_rebuild_tensor_v2" 时,返回一个 functools.partial 对象,部分应用 NotYetLoadedTensor.rebuild_tensor_v2 方法,并传入 archiveinfo=self。
当模块是 "torch._tensor" 且类名是 "_rebuild_from_type_v2" 时,返回一个 functools.partial 对象,部分应用 NotYetLoadedTensor.rebuild_from_type_v2 方法。
当模块是 "torch._utils" 且类名是 "_rebuild_parameter" 时,返回一个 functools.partial 对象,部分应用 NotYetLoadedTensor.rebuild_parameter 方法。
functools.partial: 允许创建一个新的函数,其中一些参数已经预先指定,这里是为了在实际调用时延迟具体的处理逻辑。
返回值: 如果模块和类名不匹配,返回父类的结果。
persistent_load 方法
python
def persistent_load(self, pid):
    name, cls, fn, device, size = pid
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta")
    s.archiveinfo = pid
    return s
pid: 一个包含多个信息的元组 (name, cls, fn, device, size),用于标识持久化数据的加载信息。
warnings.catch_warnings(): 捕获并管理警告信息。
warnings.simplefilter("ignore"): 忽略警告信息,以便在加载过程中不会产生干扰。
torch.storage.TypedStorage(dtype=cls().dtype, device="meta"): 创建一个 TypedStorage 对象,指定数据类型和设备。device="meta" 表示数据存储在元数据设备中,实际上并没有分配真实的存储空间。
s.archiveinfo = pid: 将持久化标识信息存储到 TypedStorage 对象中。
返回值: 返回创建的 TypedStorage 对象。
总结
LazyLoadingUnpickler 主要用于在反序列化 PyTorch 对象时实现延迟加载。这种方法使得在加载大数据文件时可以更高效地管理内存和计算资源。find_class 方法用于动态创建用于延迟加载的对象,而 persistent_load 方法则用于处理持久化存储数据的加载。

llama_model_lookup() 

init_module() 

def init_module(self, empty_init: Optional[bool] = None) -> ContextManager:
    """Instantiate the model and its parameters under this context manager to reduce peak memory usage.
在这个上下文管理器下实例化模型及其参数,以减少峰值内存使用。

The parameters get created on the device and with the right data type right away without wasting memory being allocated unnecessarily.
参数会直接在设备上创建,并且使用正确的数据类型,从而避免了不必要的内存分配浪费。

Args:
参数:

empty_init: Whether to initialize the model with empty weights (uninitialized memory).
empty_init: 是否使用空权重(未初始化的内存)来初始化模型。

If ``None``, the strategy will decide. Some strategies may not support all options.
如果``None``,则策略将决定。一些策略可能不支持所有选项。

Set this to ``True`` if you are loading a checkpoint into a large model.
如果你正在将检查点加载到大型模型中,将其设置为``True``。
    """
    self._validate_launched()
    return self._strategy.module_init_context(empty_init=empty_init)
module_init_context()  
 def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
        """A context manager wrapping the model instantiation.
一个包装模型实例化的上下文管理器。

Here, the strategy can control how the parameters of the model get created (device, dtype) and or apply other patches to the model.
在这里,策略可以控制模型参数的创建方式(设备、数据类型)或对模型应用其他修补。

Args:
参数:

empty_init: Whether to initialize the model with empty weights (uninitialized memory).
empty_init: 是否使用空权重(未初始化的内存)来初始化模型。

If ``None``, the strategy will decide. Some strategies may not support all options.
如果``None``,则策略将决定。一些策略可能不支持所有选项。
        """
        precision_module_ctx = self.precision.module_init_context()
        stack = ExitStack()
        stack.enter_context(self.root_device)
        stack.enter_context(_EmptyInit(enabled=bool(empty_init)))
        stack.enter_context(precision_module_ctx)
        return stack

quantization() 

@contextmanager
def quantization(mode: str = None):
    quantized_linear_cls = None
    if mode == 'llm.int8':
        from .quantization import Linear8bitLt
        quantized_linear_cls = Linear8bitLt
    elif mode == 'gptq.int4':
        from .quantization import ColBlockQuantizedLinear
        quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)
    elif mode == 'gptq.int8':
        from .quantization import ColBlockQuantizedLinear
        quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)
    elif mode is not None:
        raise ValueError(f"Unknown quantization mode: {mode}")

    enabled = mode is not None
    torch_linear_cls = torch.nn.Linear
    if enabled:
        torch.nn.Linear = quantized_linear_cls
    yield
    if enabled:
        torch.nn.Linear = torch_linear_cls

model 

setup() 

    def setup(
        self,
        module: nn.Module,
        *optimizers: Optimizer,
        move_to_device: bool = True,
        _reapply_compile: bool = True,
    ) -> Any:  # no specific return because the way we want our API to look does not play well with mypy
        r"""Set up a model and its optimizers for accelerated training.
为加速训练设置模型及其优化器。

Args:
参数:

module: A :class:`torch.nn.Module` to set up
module: 要设置的 :class:`torch.nn.Module`

*optimizers: The optimizer(s) to set up (no optimizers is also possible)
*optimizers: 要设置的优化器(也可以不设置优化器)

move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
move_to_device: 如果设置为``True``(默认值),则将模型移动到正确的设备。设置为``False`` 
    and alternatively use :meth:`to_device` manually.
    并可以手动使用 :meth:`to_device`。

_reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
_reapply_compile: 如果``True``(默认值),且模型之前已``torch.compile``,则
    corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
    相应的 :class:`~torch._dynamo.OptimizedModule` 包装器将被移除,并在模型被策略设置好后重新应用
    same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
    相同的设置(例如,模型被 DDP、FSDP 等包装之后)。如果编译 DDP/FSDP 造成问题,设置为``False``。

Returns:
返回:

The tuple containing wrapped module and the optimizers, in the same order they were passed in.
一个包含包装的模块和优化器的元组,顺序与传入时相同。

        """

tokenizer

 

    def encode(
        self,
        string: str,
        bos: bool = True,
        eos: bool = False,
        max_length: int = -1,
        pad: bool = False,
        device: Optional[torch.device] = None
    ) -> torch.Tensor:
        tokens = self.processor.encode(string)
        if bos:
            tokens = [self.bos_id] + tokens
        if eos:
            tokens = tokens + [self.eos_id]
        if max_length > 0:
            tokens = tokens[:max_length]
        if pad and len(tokens) < max_length:
            tokens += [self.pad_id] * (max_length - len(tokens))

        return torch.tensor(tokens, dtype=torch.int, device=device)

    def decode(self, tokens: torch.Tensor) -> str:
        return self.processor.decode(tokens.tolist())

 generate()

@torch.no_grad()
def generate(
    model: LLaMA,
    idx: torch.Tensor,
    max_new_tokens: int,
    *,
    max_seq_length: Optional[int] = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    eos_id: Optional[int] = None,
) -> torch.Tensor:
    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
接收一个条件序列(提示)作为输入,并继续生成所请求的数量的标记。

The implementation of this function is modified from A. Karpathy's nanoGPT.
此函数的实现改编自 A. Karpathy 的 nanoGPT。

Args:
参数:

model: The model to use.
model: 要使用的模型。

idx: Tensor of shape (T) with indices of the prompt sequence.
idx: 形状为 (T) 的张量,其中包含提示序列的索引。

max_new_tokens: The number of new tokens to generate.
max_new_tokens: 要生成的新分词数量。

max_seq_length: The maximum sequence length allowed.
max_seq_length: 允许的最大序列长度。

temperature: Scales the predicted logits by 1 / temperature
temperature: 通过 1 / temperature 对预测的 logits 进行缩放。

top_k: If specified, only sample among the tokens with the k highest probabilities
top_k: 如果指定,只从概率最高的 k 个标记中进行采样。

eos_id: If specified, stop generating any more token once the <eos> token is triggered
eos_id: 如果指定,一旦触发 <eos> 标记,停止生成更多标记。

    """

 https://pytorch.ac.cn/xla/release/2.1/index.htmlXLA 设备上的 PyTorch

model

    def build_rope_cache(self, idx: torch.Tensor) -> RoPECache:
        return build_rope_cache(
            seq_len=self.config.block_size,
            n_elem=self.config.n_embd // self.config.n_head,
            dtype=idx.dtype,
            device=idx.device,
        )

temperature

温度越低,结果的差距越大,会使概率分布更加尖锐,从而使得模型更倾向于选择最高概率的类别。

topk()  

def topk(input: Tensor, k: Union[_int, SymInt], dim: _int = -1, largest: _bool = True, sorted: _bool = True, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.topk: 
    r"""
    topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, LongTensor)

返回给定 input 张量在指定维度上最大的 k 个元素。

如果没有给定 dim,则选择 input 张量的最后一个维度。

如果 largest 设置为 False,则返回 k 个最小元素。

函数返回一个命名元组 (values, indices),其中 values 和 indices 分别是输入张量在指定维度 dim 上最大的 k 个元素及其索引。

布尔选项 sorted 如果为 True,则确保返回的 k 个元素按顺序排列。

参数:

input (Tensor): 输入张量。
k (int): "top-k" 中的 k 值。
dim (int, optional): 排序的维度。
largest (bool, optional): 控制是否返回最大还是最小元素。
sorted (bool, optional): 控制是否返回排序后的元素。
关键字参数:

out (tuple, optional): 可选的输出元组 (Tensor, LongTensor),可以作为输出缓冲区使用。
示例:

python
>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1.,  2.,  3.,  4.,  5.])
>>> torch.topk(x, 3)
torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))
    """

torch.multinomial

def multinomial(input: Tensor, num_samples: _int, replacement: _bool = False, *, generator: Optional[Generator] = None, out: Optional[Tensor] = None) -> Tensor: 
    r"""
    def multinomial(input: Tensor, num_samples: _int, replacement: _bool = False, *, generator: Optional[Generator] = None, out: Optional[Tensor] = None) -> Tensor:
    r"""
    multinomial(input, num_samples, replacement=False, *, generator=None, out=None) -> LongTensor
    
    返回一个张量,其中每一行包含 :attr:`num_samples` 个从对应行的多项分布中采样的索引。
    更严格地说,是从多元分布中采样,更多细节请参考 torch.distributions.multinomial.Multinomial。
    
    .. note::
        :attr:`input` 的行不需要和为 1(在这种情况下,我们使用值作为权重),
        但必须是非负的、有穷的,并且和不为零。
    
    索引按从左到右的顺序排列,依据每个索引被采样的顺序(第一个样本放在第一列)。
    
    如果 :attr:`input` 是一个向量,:attr:`out` 是一个大小为 :attr:`num_samples` 的向量。
    
    如果 :attr:`input` 是一个有 `m` 行的矩阵,则 :attr:`out` 是一个形状为
    :math:`(m \times \text{num\_samples})` 的矩阵。
    
    如果 `replacement` 为 ``True``,则样本是有放回的。
    
    如果不是,则样本是无放回的,这意味着一旦为某行绘制了一个样本索引,
    在该行中不能再次绘制相同的索引。
    
    .. note::
        当无放回采样时,:attr:`num_samples` 必须小于 :attr:`input` 中非零元素的数量
        (如果 `input` 是矩阵,则为每行的非零元素的最小数量)。
    
    Args:
        input (Tensor): 包含概率的输入张量
        num_samples (int): 要绘制的样本数量
        replacement (bool, optional): 是否允许重复抽样
    
    关键字参数:
        generator (:class:`torch.Generator`, optional): 用于采样的伪随机数生成器
        out (Tensor, optional): 输出张量。
    
    示例::
    
        >>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # 创建一个权重张量
        >>> torch.multinomial(weights, 2)
        tensor([1, 2])
        >>> torch.multinomial(weights, 4) # 错误!
        RuntimeError: invalid argument 2: invalid multinomial distribution (with replacement=False,
        not enough non-negative category to sample) at ../aten/src/TH/generic/THTensorRandom.cpp:320
        >>> torch.multinomial(weights, 4, replacement=True)
        tensor([ 2,  1,  1,  1])
    """
    """

 model.reset_cache()

 

Pytorch清空显存缓冲区(torch.cuda.empty_cache)_pytorch 清空显存-CSDN博客 
Pytorch 如何在使用模型后清除GPU内存|极客教程


http://www.kler.cn/a/289041.html

相关文章:

  • 海康工业相机的应用部署不是简简单单!?
  • Spring的IoC、Bean、DI的简单实现,难度:※※※
  • 项目开发实践——基于SpringBoot+Vue3实现的在线考试系统(七)
  • 【2024年华为OD机试】(C卷,100分)- 悄悄话 (Java JS PythonC/C++)
  • Java——Stream流的peek方法详解
  • 【网络协议】RFC3164-The BSD syslog Protocol
  • 【C++ 面试 - STL】每日 3 题(五)
  • 解读GaussianTalker:利用音频驱动的基于3D高斯点染技术的实时高保真讲话头像合成
  • Idea_服务器自动化部署_傻瓜式教程
  • MySQL中的分组统计
  • 云计算环境下的数据治理
  • 学习之git
  • 算法设计:实验二贪心算法
  • wget下载速度受到哪些因素影响?
  • MySQL:简述多版本并发控制MVCC
  • 无人机之电池篇
  • Python与R的完美协作:深入解析subprocess模块调用R脚本的参数传递机制
  • 安装WMware和Ubuntu并使用xShell连接
  • Map排序与转换的深入探索:Java与Kotlin的实现与应用
  • 宝兰德多款仓颉开源项目获GitCode官方G-Star毕业认证,释放开发效率新动能
  • 将军百战死,程序十年成
  • Spring Cloud Eureka与Kubernetes的集成:服务发现的混合方案
  • YOLO-World: Real-Time Open-Vocabulary Object Detection:实时开放词汇对象检测
  • QT教程-十七,QTextBrowser
  • dnsperf测试dns性能
  • 春秋云镜initial