Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "ggml-cuda/pool2d.cuh"
#include "ggml-cuda/quantize.cuh"
#include "ggml-cuda/rope.cuh"
#include "ggml-cuda/roll.cuh"
#include "ggml-cuda/scale.cuh"
#include "ggml-cuda/softmax.cuh"
#include "ggml-cuda/ssm-conv.cuh"
Expand Down Expand Up @@ -2419,6 +2420,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_ROPE_BACK:
ggml_cuda_op_rope_back(ctx, dst);
break;
case GGML_OP_ROLL:
ggml_cuda_op_roll(ctx, dst);
break;
case GGML_OP_IM2COL:
ggml_cuda_op_im2col(ctx, dst);
break;
Expand Down Expand Up @@ -3411,6 +3415,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
return max_bias == 0.0f;
}
case GGML_OP_ROLL:
if(op->src[0]->type == GGML_TYPE_F32) {
return true;
}
return false;
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK: {
return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
Expand Down
67 changes: 67 additions & 0 deletions ggml/src/ggml-cuda/roll.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#include "ggml-cuda/common.cuh"
#include "roll.cuh"

static __forceinline__ __device__ int64_t wrap_index(const int64_t idx, const int64_t ne) {
if (idx < 0) {
return idx + ne;
}
if (idx >= ne) {
return idx - ne;
}
return idx;
}

static __global__ void roll_f32_cuda(const float * __restrict__ src,
float * __restrict__ dst,
const int64_t ne00,
const int64_t ne01,
const int64_t ne02,
const int64_t ne03,
const int s0,
const int s1,
const int s2,
const int s3) {
const int64_t idx = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
const int64_t n_elements = ne00 * ne01 * ne02 * ne03;

if (idx >= n_elements) {
return;
}

const int64_t i0 = idx % ne00;
const int64_t i1 = (idx / ne00) % ne01;
const int64_t i2 = (idx / (ne00 * ne01)) % ne02;
const int64_t i3 = (idx / (ne00 * ne01 * ne02)) % ne03;

const int64_t d0 = wrap_index(i0 - s0, ne00);
const int64_t d1 = wrap_index(i1 - s1, ne01);
const int64_t d2 = wrap_index(i2 - s2, ne02);
const int64_t d3 = wrap_index(i3 - s3, ne03);

dst[i3 * (ne00 * ne01 * ne02) + i2 * (ne01 * ne00) + i1 * ne00 + i0] =
src[d3 * (ne00 * ne01 * ne02) + d2 * (ne01 * ne00) + d1 * ne00 + d0];
}

void ggml_cuda_op_roll(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
int s0 = dst->op_params[0];
int s1 = dst->op_params[1];
int s2 = dst->op_params[2];
int s3 = dst->op_params[3];

const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) dst->src[0]->data;
float * dst_d = (float *) dst->data;

GGML_TENSOR_UNARY_OP_LOCALS;

GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_are_same_shape(dst->src[0], dst));

cudaStream_t stream = ctx.stream();

int64_t sz = (ne00 * ne01 * ne02 * ne03);
int64_t num_blocks = (sz + CUDA_ROLL_BLOCK_SIZE - 1) / CUDA_ROLL_BLOCK_SIZE;

roll_f32_cuda<<<num_blocks, CUDA_ROLL_BLOCK_SIZE, 0, stream>>>(
src0_d, dst_d, ne00, ne01, ne02, ne03, s0, s1, s2, s3);
}
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/roll.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include "common.cuh"

#define CUDA_ROLL_BLOCK_SIZE 256

void ggml_cuda_op_roll(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
38 changes: 25 additions & 13 deletions scripts/server-bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
return ret


def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int) -> list[int]:
def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int, seed_offset: int) -> list[int]:
assert n_prompts >= 0
ret: list[int] = []
for i in range(n_prompts):
random.seed(13 * i + 0)
if seed_offset >= 0:
random.seed(3 * (seed_offset + 1000 * i) + 0)
ret.append(random.randint(prompt_length_min, prompt_length_max))
return ret

Expand All @@ -46,12 +47,20 @@ def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:


def get_server(path_server: str, path_log: Optional[str]) -> dict:
logger.info("Starting the llama.cpp server...")
hostname: str = os.environ.get("LLAMA_ARG_HOST", "127.0.0.1")
port: str = os.environ.get("LLAMA_ARG_PORT", "8080")
if os.environ.get("LLAMA_ARG_HOST") is None:
logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
if os.environ.get("LLAMA_ARG_PORT") is None:
logger.info("LLAMA_ARG_PORT not explicitly set, using 8080")
os.environ["LLAMA_ARG_PORT"] = "8080"
hostname: Optional[str] = os.environ.get("LLAMA_ARG_HOST")
port: Optional[str] = os.environ.get("LLAMA_ARG_PORT")
assert hostname is not None
assert port is not None
address: str = f"http://{hostname}:{port}"
logger.info(f"Starting the llama.cpp server under {address}...")

fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL
fout = open(path_log.format(port=port), "w") if path_log is not None else subprocess.DEVNULL
process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)

n_failures: int = 0
Expand All @@ -60,7 +69,7 @@ def get_server(path_server: str, path_log: Optional[str]) -> dict:
sleep(1.0)
exit_code = process.poll()
if exit_code is not None:
raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}, see {path_log}")
raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}{path_log and f', see {path_log.format(port=port)}' or ''}")
response = requests.get(f"{address}/health")
if response.status_code == 200:
break
Expand Down Expand Up @@ -128,7 +137,7 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
return (t_submit, token_arrival_times)


def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int):
def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int, seed_offset: int):
if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
Expand All @@ -139,7 +148,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"

parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL", 1))
parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL")) # type: ignore
prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
synthetic_prompts: bool = prompts is None
prompt_n = []
Expand All @@ -151,7 +160,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
prompt_length_min: int = int(prompt_source_split[1])
prompt_length_max: int = int(prompt_source_split[2])
logger.info("Generating random prompts...")
prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max)
prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max, seed_offset)
prompts = get_prompts_rng(prompt_n)
else:
n_predict_min = n_predict
Expand All @@ -176,10 +185,11 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
data: list[dict] = []

for i, p in enumerate(prompts):
random.seed(13 * i + 1)
if seed_offset >= 0:
random.seed(3 * (seed_offset + 1000 * i) + 1)
data.append({
"session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
"n_predict": random.randint(n_predict_min, n_predict), "seed": 13 * i + 2})
"n_predict": random.randint(n_predict_min, n_predict), "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})

if not synthetic_prompts:
logger.info("Getting the prompt lengths...")
Expand Down Expand Up @@ -251,7 +261,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
"Results are printed to console and visualized as plots (saved to current working directory). "
"To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
parser.add_argument("--path_log", type=str, default="server-bench.log", help="Path to the model to use for the benchmark")
parser.add_argument("--path_log", type=str, default="server-bench-{port}.log", help="Path to the model to use for the benchmark")
parser.add_argument(
"--prompt_source", type=str, default="rng-1024-2048",
help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
Expand All @@ -261,5 +271,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
parser.add_argument(
"--n_predict_min", type=int, default=1024,
help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
parser.add_argument("--seed_offset", type=int, default=0, help="Offset for determining the seeds for pseudorandom prompt/generation lengths. "
"Corelations between seeds can occur when set >= 1000. Negative values mean no seed.")
args = parser.parse_args()
benchmark(**vars(args))
Loading