[心得] 用 Matlab 寫 MEX 函數加速 vol.4
最後一篇, 這篇會提出一個簡單的 MEX file 的 Framework
廢話不多說, 直接看 example, 程式碼才是最好的說明書
這個 example 實際上是 vol.2 和 vol.3 的合體
如果前面的程式都看懂了, 這個程式應該不會有任何困難
建議把 code 複製到 Matlab 的 Editor 來看, 比較不會傷眼睛 :p
範例程式接受一個 2-D 的 input, 把每個元素 +1 以後輸出
output = input + 1;
在程式碼裡, 我把程式區分了很多塊
以一個標準的 MEX 檔來說, 大概會有這幾個部分
* 取得輸入參數的資料 (eg. dimension)
* 配置輸出參數 (Output Allocation)
* 根據輸入, 計算輸出 (Data Processing)
把下面的程式看懂, 碰到要寫 MEX 的時候
直接套下面這個範例, 應該可以省不少時間 :)
#include "mex.h"
#include <math.h>
#include <stdio.h>
// Program framework for input, output and processing
// usage:
// input = zeros(2, 10)
// input(1, :) = 1:10
// input(2, :) = 10:-1:1
// mex test3.c
// a = test3(input)
// note: type all commands above in Matlab Command Window
void mexFunction(
int nlhs, mxArray *plhs[],
int nrhs, const mxArray *prhs[])
{
int i, j, k;
// input
int input_dim_x;
int input_dim_y;
double *in; // pointer to process content of the input
// output
int output_dimension[2]; // note: you should specify output dimension
int output_dim_x;
int output_dim_y;
double *out;
/* -------------------------------- */
/* NECESSARY (input processing) */
in = mxGetPr(prhs[0]); // get data pointer
// mx: Matrix
// P: Pointer
// r: real
// mxGetPr() has a counterprt mxGetPi();
input_dim_x = mxGetM(prhs[0]);
input_dim_y = mxGetN(prhs[0]);
// print some message about input data
printf("nrhs: %d\n", nrhs);
printf("mxGetM(prhs[0]): %d\n", mxGetM(prhs[0]));
printf("mxGetN(prhs[0]): %d\n", mxGetN(prhs[0]));
// print input content
for(i=0; i<input_dim_x; i++) // x
for(j=0; j<input_dim_y; j++) // y
// notice: data type is "float", you shall use "%f" insted of "%d"
printf("%f\n", in[i + j*input_dim_x]);
/* -------------------------------- */
/* NECESSARY (output allocation) */
// specify output matrix's dimension
output_dim_x = input_dim_x;
output_dim_y = input_dim_y;
output_dimension[0] = output_dim_x; // mxCreateNumericArray() required
output_dimension[1] = output_dim_y;
// Allocate the output matrix
plhs[0] = mxCreateNumericArray(
2,
output_dimension,
mxDOUBLE_CLASS,
mxREAL);
// Get output matrix's data pointer
out = mxGetPr(plhs[0]); // just point at the start
/* -------------------------------- */
/* NECESSARY (data processing) */
// Copy and add 1
for(i=0; i<output_dim_x; i++)
for(j=0; j<output_dim_y; j++)
{
int temp = i + j*input_dim_x;
out[temp] = in[temp]*2;
}
}
--
※ 發信站: 批踢踢實業坊(ptt.cc)
◆ From: 140.113.128.237
推
05/17 11:24, , 1F
05/17 11:24, 1F
推
05/17 23:19, , 2F
05/17 23:19, 2F
→
05/17 23:20, , 3F
05/17 23:20, 3F
MATLAB 近期熱門文章
PTT數位生活區 即時熱門文章