Re: [問題] 關於這段程式碼 要如何優化?
※ 引述《jackhzt (巴克球)》之銘言:
: [問題類型]:
: 效能諮詢(我想讓R 跑更快)
: [軟體熟悉度]:
: 使用者(已經有用R 做過不少作品)
: [問題敘述]:如何將以下的程式碼跑快一點
: [程式範例]:R中的dist這function 因為想使用不同的計算方式,所以希望可以做到
: 和此function表現差不多的function
: 程式碼可貼於以下網站:
: https://gist.github.com/anonymous/cf844933bb6858936e25
: 希望有提高效率的方法
1. 如果距離函數是常見的,通常建議用dist達成
2. 如果不常見,盡量考慮用矩陣運算求出來
EX: (以歐式距離來說)
library(magrittr)
x <- matrix(rnorm(50), 5, 10)
distMat <- sweep(-x %*% t(x) * 2, 2, rowSums(x^2), '+') %>%
sweep(1, rowSums(x^2), '+')
diag(distMat) <- 0
distMat %<>% sqrt
all.equal(distMat, as.matrix(dist(x)), check.attributes = FALSE) # TRUE
3. 用簡單的平行,我可能會這樣做:
library(magrittr)
library(foreach)
library(doSNOW)
library(plyr)
library(Matrix)
dis <- function(x, y) sum(abs(as.numeric(x)-as.numeric(y)))
x <- matrix(rnorm(50), 5, 10)
allCombinations <- combn(1:nrow(x), 2)
cl <- makeCluster(8, type = "SOCK")
registerDoSNOW(cl)
clusterExport(cl, list = c("dis", "x"))
res <- aaply(allCombinations, 2, function(v){
dis(x[v[1],], x[v[2],])
}, .parallel = TRUE)
stopCluster(cl)
distMat <- sparseMatrix(i = allCombinations[1,],
j = allCombinations[2,], x = res) %>%
rbind(0) %>% as.matrix
distMat[lower.tri(distMat)] <- res
4. 或是乾脆用RcppArmadillo:
(下面是用之前kernel matrix估計的方法,之前有發過更快的方法...)
library(Rcpp)
library(RcppArmadillo)
## For windows user
# library(inline)
# settings <- getPlugin("Rcpp")
# settings$env$PKG_CXXFLAGS <- paste('-fopenmp', settings$env$PKG_CXXFLAGS)
# settings$env$PKG_LIBS <- paste('-fopenmp -lgomp', settings$env$PKG_LIBS)
# do.call(Sys.setenv, settings$env)
sourceCpp(code = '
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
#include <omp.h>
// [[Rcpp::plugins(openmp)]]
using namespace Rcpp;
using namespace arma;
// [[Rcpp::export]]
NumericMatrix kernelMatrix_cpp(NumericMatrix Xr, NumericMatrix Centerr,
double sigma) {
omp_set_num_threads(omp_get_max_threads());
uword n = Xr.nrow(), b = Centerr.nrow(), row_index, col_index;
mat X(Xr.begin(), n, Xr.ncol(), false);
mat Center(Centerr.begin(), b, Centerr.ncol(), false);
mat KerX(n, b);
#pragma omp parallel private(row_index, col_index)
for (row_index = 0; row_index < n; row_index++)
{
#pragma omp for nowait
for (col_index = 0; col_index < b; col_index++)
{
KerX(row_index, col_index) = exp(sum(square(X.row(row_index)
- Center.row(col_index))) / (-2.0 * sigma * sigma));
}
}
return wrap(KerX);
}')
--
R資料整理套件系列文:
magrittr #1LhSWhpH (R_Language) http://tinyurl.com/1LhSWhpH
data.table #1LhW7Tvj (R_Language) http://tinyurl.com/1LhW7Tvj
dplyr(上) #1LhpJCfB (R_Language) http://tinyurl.com/1LhpJCfB
dplyr(下) #1Lhw8b-s (R_Language)
tidyr #1Liqls1R (R_Language) http://tinyurl.com/1Liqls1R
--
※ 發信站: 批踢踢實業坊(ptt.cc), 來自: 140.109.73.238
※ 文章網址: https://www.ptt.cc/bbs/R_Language/M.1457718749.A.B23.html
※ 編輯: celestialgod (140.109.73.238), 03/12/2016 01:55:02
推
03/12 02:03, , 1F
03/12 02:03, 1F
推
03/12 03:27, , 2F
03/12 03:27, 2F
→
03/12 03:27, , 3F
03/12 03:27, 3F
正常
推
03/12 03:42, , 4F
03/12 03:42, 4F
對
→
03/12 03:43, , 5F
03/12 03:43, 5F
→
03/12 03:44, , 6F
03/12 03:44, 6F
沒有
→
03/12 03:45, , 7F
03/12 03:45, 7F
→
03/12 03:46, , 8F
03/12 03:46, 8F
可以回文繼續問
推
03/12 17:43, , 9F
03/12 17:43, 9F
po一下你的計算公式,我測試看看
推
03/12 18:24, , 10F
03/12 18:24, 10F
→
03/12 18:24, , 11F
03/12 18:24, 11F
→
03/12 18:25, , 12F
03/12 18:25, 12F
→
03/12 18:26, , 13F
03/12 18:26, 13F
請問什麼是R17?
推
03/12 19:34, , 14F
03/12 19:34, 14F
→
03/12 19:35, , 15F
03/12 19:35, 15F
推
03/12 20:03, , 16F
03/12 20:03, 16F
不客氣,簡單平行還太慢就得想辦法(攤手
※ 編輯: celestialgod (180.218.152.118), 03/12/2016 20:11:27
討論串 (同標題文章)
本文引述了以下文章的的內容:
完整討論串 (本文為第 2 之 2 篇):
R_Language 近期熱門文章
PTT數位生活區 即時熱門文章
14
26