3 changes: 3 additions & 0 deletions include/infinicore/ops.hpp
@@ -4,12 +4,15 @@
#include "ops/add_rms_norm.hpp"
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/flash_attention.hpp"
#include "ops/kv_caching.hpp"
#include "ops/matmul.hpp"
#include "ops/ones.hpp"
#include "ops/paged_attention.hpp"
#include "ops/paged_attention_prefill.hpp"
#include "ops/paged_caching.hpp"
#include "ops/random_sample.hpp"
#include "ops/random_sample_batched.hpp"
#include "ops/rearrange.hpp"
#include "ops/rms_norm.hpp"
#include "ops/rope.hpp"
12 changes: 12 additions & 0 deletions include/infinicore/ops/flash_attention.hpp
@@ -0,0 +1,12 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {

INFINICORE_GRAPH_OP_CLASS(FlashAttention, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, float, bool);

Tensor flash_attention(const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal);
void flash_attention_(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal);
} // namespace infinicore::op
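
A minimal C++ sketch of how the two entry points declared above might be called. The `<infinicore/ops.hpp>` include path, the `infinicore::Tensor` spelling, and the pre-built tensors are assumptions; tensor construction is not part of this diff.

// Sketch only: q, k, v and total_kv_len are assumed to be tensors built elsewhere.
#include <infinicore/ops.hpp>

using infinicore::Tensor;

Tensor prefill_attention(const Tensor &q, const Tensor &k, const Tensor &v,
                         const Tensor &total_kv_len, float scale) {
    // Out-of-place variant: allocates and returns the output tensor.
    return infinicore::op::flash_attention(q, k, v, total_kv_len, scale,
                                           /*is_causal=*/true);
}

void prefill_attention_into(Tensor out, const Tensor &q, const Tensor &k,
                            const Tensor &v, const Tensor &total_kv_len,
                            float scale) {
    // In-place variant: writes the result into the caller-provided `out` tensor.
    infinicore::op::flash_attention_(out, q, k, v, total_kv_len, scale,
                                     /*is_causal=*/true);
}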
16 changes: 16 additions & 0 deletions include/infinicore/ops/kv_caching.hpp
@@ -0,0 +1,16 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

namespace infinicore::op {

INFINICORE_GRAPH_OP_CLASS(KVCaching, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &);

void kv_caching_(Tensor k_cache,
Tensor v_cache,
const Tensor &k,
const Tensor &v,
const Tensor &past_kv_lengths);
} // namespace infinicore::op
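
A hedged sketch of calling the in-place entry point above. The include path and `infinicore::Tensor` are assumptions, and the exact semantics (writing `k`/`v` into the caches at the per-sequence offsets given by `past_kv_lengths`) are inferred from the signature rather than stated in this diff.

// Sketch only: all tensors are assumed to be built elsewhere.
#include <infinicore/ops.hpp>

using infinicore::Tensor;

void append_to_kv_cache(Tensor k_cache, Tensor v_cache,
                        const Tensor &k, const Tensor &v,
                        const Tensor &past_kv_lengths) {
    // Updates k_cache/v_cache in place; presumably appends the new k/v entries
    // at the offsets recorded in past_kv_lengths.
    infinicore::op::kv_caching_(k_cache, v_cache, k, v, past_kv_lengths);
}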
20 changes: 20 additions & 0 deletions include/infinicore/ops/random_sample_batched.hpp
@@ -0,0 +1,20 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {

class RandomSampleBatched {
public:
using schema = void (*)(Tensor, Tensor, const float *, const float *, const int *, const float *, int);
static void execute(Tensor result, Tensor probs, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size);
static common::OpDispatcher<schema> &dispatcher();
};

// Out-of-place API
Tensor random_sample_batched(Tensor logits, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size);
// In-place API
void random_sample_batched_(Tensor indices, Tensor logits, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size);

} // namespace infinicore::op
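
A hedged sketch of the out-of-place batched sampling call declared above. The include path and `infinicore::Tensor` are assumptions, and the per-request parameter arrays are assumed to be host-side buffers of length `batch_size`, which is how the raw-pointer signature reads.

// Sketch only: `logits` is assumed to be a pre-built tensor of per-request logits.
#include <infinicore/ops.hpp>

#include <vector>

using infinicore::Tensor;

Tensor sample_batch(Tensor logits, int batch_size) {
    // One sampling configuration per request in the batch (assumed host arrays).
    std::vector<float> random_val(batch_size, 0.5f);
    std::vector<float> topp(batch_size, 0.9f);
    std::vector<int> topk(batch_size, 50);
    std::vector<float> temperature(batch_size, 1.0f);

    // Out-of-place API: returns a tensor of sampled token indices.
    return infinicore::op::random_sample_batched(
        logits, random_val.data(), topp.data(), topk.data(),
        temperature.data(), batch_size);
}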
3 changes: 3 additions & 0 deletions include/infiniop.h
@@ -9,8 +9,10 @@
#include "infiniop/ops/clip.h"
#include "infiniop/ops/conv.h"
#include "infiniop/ops/dequantize_awq.h"
#include "infiniop/ops/flash_attention.h"
#include "infiniop/ops/gelu.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/kv_caching.h"
#include "infiniop/ops/layer_norm.h"
#include "infiniop/ops/logsoftmax.h"
#include "infiniop/ops/lp_norm.h"
@@ -20,6 +22,7 @@
#include "infiniop/ops/paged_attention_prefill.h"
#include "infiniop/ops/paged_caching.h"
#include "infiniop/ops/random_sample.h"
#include "infiniop/ops/random_sample_batched.h"
#include "infiniop/ops/rearrange.h"
#include "infiniop/ops/relu.h"
#include "infiniop/ops/rms_norm.h"
36 changes: 36 additions & 0 deletions include/infiniop/ops/flash_attention.h
@@ -0,0 +1,36 @@
#ifndef __INFINIOP_FLASH_ATTENTION_API_H__
#define __INFINIOP_FLASH_ATTENTION_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopFlashAttentionDescriptor_t;

__C __export infiniStatus_t infiniopCreateFlashAttentionDescriptor(
infiniopHandle_t handle,
infiniopFlashAttentionDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t total_kv_len,
float scale,
char is_causal);

__C __export infiniStatus_t infiniopGetFlashAttentionWorkspaceSize(
infiniopFlashAttentionDescriptor_t desc,
size_t *size);

__C __export infiniStatus_t infiniopFlashAttention(
infiniopFlashAttentionDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k,
const void *v,
const void *total_kv_len,
void *stream);

__C __export infiniStatus_t infiniopDestroyFlashAttentionDescriptor(
infiniopFlashAttentionDescriptor_t desc);
#endif
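
A hedged sketch of the descriptor lifecycle declared above: create, query the workspace size, run, destroy. The `<infiniop.h>` include, the caller-provided handle/descriptors/buffers/stream, and the workspace allocation are assumptions; error checking is omitted. The new KVCaching and RandomSampleBatched descriptors below follow the same pattern.

// Sketch only: handle, tensor descriptors, device buffers, workspace and stream
// are assumed to be set up elsewhere; every call also returns an infiniStatus_t
// that real code should check.
#include <infiniop.h>

infiniStatus_t run_flash_attention(
    infiniopHandle_t handle,
    infiniopTensorDescriptor_t out_desc, infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_desc, infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t total_kv_len_desc,
    void *out, const void *q, const void *k, const void *v,
    const void *total_kv_len,
    void *workspace, float scale, void *stream) {
    infiniopFlashAttentionDescriptor_t desc;
    infiniopCreateFlashAttentionDescriptor(handle, &desc, out_desc, q_desc,
                                           k_desc, v_desc, total_kv_len_desc,
                                           scale, /*is_causal=*/1);

    size_t workspace_size = 0;
    infiniopGetFlashAttentionWorkspaceSize(desc, &workspace_size);
    // `workspace` is assumed to provide at least `workspace_size` bytes.

    infiniStatus_t status = infiniopFlashAttention(
        desc, workspace, workspace_size, out, q, k, v, total_kv_len, stream);

    infiniopDestroyFlashAttentionDescriptor(desc);
    return status;
}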
31 changes: 31 additions & 0 deletions include/infiniop/ops/kv_caching.h
@@ -0,0 +1,31 @@
#ifndef __INFINIOP_KV_CACHING_API_H__
#define __INFINIOP_KV_CACHING_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopKVCachingDescriptor_t;

__C __export infiniStatus_t infiniopCreateKVCachingDescriptor(
infiniopHandle_t handle,
infiniopKVCachingDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t k_cache,
infiniopTensorDescriptor_t v_cache,
infiniopTensorDescriptor_t k,
infiniopTensorDescriptor_t v,
infiniopTensorDescriptor_t past_kv_lengths);

__C __export infiniStatus_t infiniopGetKVCachingWorkspaceSize(infiniopKVCachingDescriptor_t desc, size_t *size);

__C __export infiniStatus_t infiniopKVCaching(infiniopKVCachingDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *k_cache,
void *v_cache,
const void *k,
const void *v,
const void *past_kv_lengths,
void *stream);

__C __export infiniStatus_t infiniopDestroyKVCachingDescriptor(infiniopKVCachingDescriptor_t desc);

#endif
6 changes: 0 additions & 6 deletions include/infiniop/ops/random_sample.h
@@ -15,12 +15,6 @@ __C __export infiniStatus_t infiniopGetRandomSampleWorkspaceSize(
infiniopRandomSampleDescriptor_t desc,
size_t *size);

__C __export infiniStatus_t infiniopCreateRandomSampleBatchDescriptor(
infiniopHandle_t handle,
infiniopRandomSampleDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t result,
infiniopTensorDescriptor_t probs);

__C __export infiniStatus_t infiniopRandomSample(
infiniopRandomSampleDescriptor_t desc,
void *workspace,
34 changes: 34 additions & 0 deletions include/infiniop/ops/random_sample_batched.h
@@ -0,0 +1,34 @@
#ifndef __INFINIOP_RANDOM_SAMPLE_BATCHED_API_H__
#define __INFINIOP_RANDOM_SAMPLE_BATCHED_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopRandomSampleBatchedDescriptor_t;

__C __export infiniStatus_t infiniopCreateRandomSampleBatchedDescriptor(
infiniopHandle_t handle,
infiniopRandomSampleBatchedDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t result,
infiniopTensorDescriptor_t probs);

__C __export infiniStatus_t infiniopGetRandomSampleBatchedWorkspaceSize(
infiniopRandomSampleBatchedDescriptor_t desc,
size_t *size);

__C __export infiniStatus_t infiniopRandomSampleBatched(
infiniopRandomSampleBatchedDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *result,
const void *probs,
const float *random_val,
const float *topp,
const int *topk,
const float *temperature,
int batch_size,
void *stream);

__C __export infiniStatus_t infiniopDestroyRandomSampleBatchedDescriptor(
infiniopRandomSampleBatchedDescriptor_t desc);

#endif
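
A short hedged sketch of driving the batched sampler through the C API above. Descriptor and workspace setup are assumed to follow the same lifecycle as the FlashAttention sketch earlier; the sampling parameters are assumed to be host arrays of length `batch_size`.

// Sketch only: desc, workspace, result/probs buffers and stream are prepared
// elsewhere; error checking is omitted.
#include <infiniop.h>

#include <vector>

infiniStatus_t sample_batched(infiniopRandomSampleBatchedDescriptor_t desc,
                              void *workspace, size_t workspace_size,
                              void *result, const void *probs,
                              int batch_size, void *stream) {
    // Per-request sampling parameters (assumed to live on the host).
    std::vector<float> random_val(batch_size, 0.5f);
    std::vector<float> topp(batch_size, 0.9f);
    std::vector<int> topk(batch_size, 50);
    std::vector<float> temperature(batch_size, 1.0f);

    return infiniopRandomSampleBatched(desc, workspace, workspace_size,
                                       result, probs,
                                       random_val.data(), topp.data(),
                                       topk.data(), temperature.data(),
                                       batch_size, stream);
}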
2 changes: 2 additions & 0 deletions python/infinicore/__init__.py
@@ -45,6 +45,7 @@
from infinicore.ops.add import add
from infinicore.ops.add_rms_norm import add_rms_norm, add_rms_norm_
from infinicore.ops.attention import attention
from infinicore.ops.kv_caching import kv_caching
from infinicore.ops.matmul import matmul
from infinicore.ops.mul import mul
from infinicore.ops.narrow import narrow
@@ -115,6 +116,7 @@
"add_rms_norm",
"add_rms_norm_",
"attention",
"kv_caching",
"matmul",
"mul",
"narrow",
10 changes: 7 additions & 3 deletions python/infinicore/nn/functional/__init__.py
@@ -1,20 +1,24 @@
from .causal_softmax import causal_softmax
from .embedding import embedding
from .flash_attention import flash_attention
from .linear import linear
from .random_sample import random_sample
from .rms_norm import rms_norm
from .rope import RopeAlgo, rope
from .scaled_dot_product_attention import scaled_dot_product_attention
from .silu import silu
from .swiglu import swiglu

__all__ = [
"causal_softmax",
"embedding",
"flash_attention",
"linear",
"random_sample",
"rms_norm",
"rope",
"scaled_dot_product_attention",
"silu",
"swiglu",
"linear",
"embedding",
"rope",
"RopeAlgo",
]
34 changes: 34 additions & 0 deletions python/infinicore/nn/functional/flash_attention.py
@@ -0,0 +1,34 @@
import math

from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def flash_attention(
query,
key,
value,
total_kv_len,
attn_mask=None,
dropout_p=0,
is_causal=False,
scale=None,
enable_gqa=False,
):
assert attn_mask is None and dropout_p == 0 and not enable_gqa

emb_dim = query.shape[-1]

if scale is None:
scale = 1 / math.sqrt(emb_dim)

return Tensor(
_infinicore.flash_attention(
query._underlying,
key._underlying,
value._underlying,
total_kv_len._underlying,
scale,
is_causal,
)
)
35 changes: 35 additions & 0 deletions python/infinicore/nn/functional/scaled_dot_product_attention.py
@@ -0,0 +1,35 @@
import math

from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def scaled_dot_product_attention(
query,
key,
value,
attn_mask=None,
dropout_p=0,
is_causal=False,
scale=None,
enable_gqa=False,
):
raise NotImplementedError("Scaled Dot Product Attention is not yet supported.")

assert attn_mask is None and dropout_p == 0 and not enable_gqa

emb_dim = query.shape[-1]

if scale is None:
scale = 1 / math.sqrt(emb_dim)

return Tensor(
_infinicore.flash_attention(
query._underlying,
key._underlying,
value._underlying,
key.shape[-2],
scale,
is_causal,
)
)
13 changes: 13 additions & 0 deletions python/infinicore/ops/kv_caching.py
@@ -0,0 +1,13 @@
from infinicore.lib import _infinicore


def kv_caching(k_cache, v_cache, k, v, past_kv_lengths):
_infinicore.kv_caching_(
k_cache._underlying,
v_cache._underlying,
k._underlying,
v._underlying,
past_kv_lengths._underlying,
)

return k_cache, v_cache
33 changes: 25 additions & 8 deletions scripts/build_ntops.py
@@ -1,3 +1,4 @@
import concurrent.futures
import importlib
import pathlib

@@ -11,16 +12,32 @@
def _find_and_build_ops():
ops_path = SRC_DIR_PATH / "infiniop" / "ops"

for op_dir in ops_path.iterdir():
ninetoothed_path = op_dir / "ninetoothed"
with concurrent.futures.ProcessPoolExecutor() as executor:
futures = []

if ninetoothed_path.is_dir():
module_path = ninetoothed_path / "build"
relative_path = module_path.relative_to(SRC_DIR_PATH)
import_name = ".".join(relative_path.parts)
module = importlib.import_module(import_name)
for op_dir in ops_path.iterdir():
ninetoothed_path = op_dir / "ninetoothed"

module.build()
if not ninetoothed_path.is_dir():
continue

build_file = ninetoothed_path / "build.py"
if not build_file.exists():
continue

futures.append(executor.submit(_build, ninetoothed_path))

for future in concurrent.futures.as_completed(futures):
future.result()


def _build(ninetoothed_path):
module_path = ninetoothed_path / "build"
relative_path = module_path.relative_to(SRC_DIR_PATH)
import_name = ".".join(relative_path.parts)
module = importlib.import_module(import_name)

module.build()


if __name__ == "__main__":