当前位置:   article > 正文

迁移学习之TCA算法_tca迁移学习

tca迁移学习

TCA(迁移成分分析):
TCA算法解决的问题:和PCA算法有点像,可以实现降维,两个高维的大矩阵(源域和目标域矩阵)进去,得到两个低维的矩阵(降维后的源域和目标域矩阵)。TCA可以将分布不同的源域和目标域数据映射到高维再生核希尔伯特空间中,然后不断缩小源域和目标域的距离并最大程度的保留其内部属性。

TCA如何进行迁移:寻找一个特征映射,使得映射后的数据分布源域和目标域概率密度相等,并且条件概率密度也相等。由于迁移学习的本质是最小化源域和目标域的距离,因此TCA利用到了MMD(最大均值差异)算法来衡量源域和目标域的距离

TCA步骤:首先计算L(MMD引入的矩阵)和H(中心矩阵)矩阵,然后选择常用的核函数进行映射求得K(核矩阵),然后求 ( K L K + μ I ) − 1 K H K (KLK+\mu I)^{-1}KHK (KLK+μI)1KHK的前m个特征值。

数据集来自https://pan.baidu.com/s/1bp4g7Av,加载数据并进行简单的归一化

load (’Caltech. mat ’) ; % source domain
fts = fts./repmat(sum(fts,2),1,size(fts,2)) ;
Xs = zscore (fts ,1) ;clearfts
Ys = labels ;clearlabels
load (’amazon.mat’) ; % targ et domain
fts = fts./repmat(sum(fts,2),1,size(fts,2)) ;
Xt = zscore (fts ,1) ;clearfts
Yt = labels ;clearlabels
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

TCA算法matlab代码:
TCA.m

function [ X_src_new , X_tar_new ,A] = TCA(X_src , X_tar , options )
lambda = options.lambda ;
dim = options.dim ;
kernel_type = options.kernel_type ;
gamma = options.gamma;
X = [ X_src' , X_tar'] ;
X = X*diag ( sparse ( 1./sqrt(sum(X.^2 ) ) ) ) ;
[m, n ] = size(X) ;
ns = size(X_src , 1 ) ;
nt = size(X_tar,1);
e = [1/ ns*ones(ns,1);-1/nt*ones(nt,1)];
M = e*e' ;
M = M / norm(M,'fro') ;
H = eye (n)-1/(n)*ones (n,n);
if strcmp ( kernel_type ,'primal')
[A, ~ ] = eigs(X*M*X'+lambda*eye (m) ,X*H*X' , dim , 'SM') ;
Z = A'*X;
Z = Z*diag ( sparse ( 1./sqrt (sum(Z.^2 ) ) ) ) ;
X_src_new = Z ( : , 1 : ns )';
X_tar_new = Z ( : , ns+1:end )';
else
K = TCA_kernel( kernel_type ,X, [ ] , gamma) ;
[A, ~ ] = eigs(K*M*K'+lambda*eye(n) ,K*H*K' ,dim,'SM') ;
Z = A'*K;
Z = Z*diag ( sparse (1./sqrt(sum(Z.^2)))) ;
X_src_new = Z(:,1:ns)' ;
X_tar_new = Z(:,ns+1:end)';
end
end

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30

TCA_kernel.m

function K = TCA_kernel( ker ,X, X2,gamma)
switch ker
case 'linear'
if isempty (X2)
K = X'*X;
else
K = X'*X2;
end
case 'rbf'
n1sq = sum(X.^2,1);
n1 = size(X,2);
if isempty (X2)
D = ( ones (n1 , 1 ) *n1sq )' + ones (n1 , 1 )*n1sq -2*X'*X;
else
n2sq = sum(X2.^2 , 1 ) ;
n2 = size(X2, 2 );
D = ( ones (n2 , 1 ) *n1sq )'+ ones (n1 , 1 )*n2sq -2*X'*X2;
end
K = exp(-gamma*D) ;
case 'sam'
if isempty (X2)
D = X'*X;
else
D = X'*X2;
end
K = exp(-gamma*acos (D).^2 ) ;
otherwise
error( ['Unsupported kernel' ker ] )
end
end
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30

test.m

function [acc,X_src_new,X_tar_new] = test() 
options.gamma=2;
options.kernel_type='linear';
options.lambda=1.0;
options.dim=20;
[X_src_new,X_tar_new,A]=TCA(Xs,Xt,options);
knn_model=fitcknn(X_src_new,Ys,'NumNeighbors',1);
Y_tar_pesudo=knn_model.predict(X_tar_new);
acc=length(find(Y_tar_pesudo==Yt))/length(Yt);
fprintf('Acc=%0.4f\n',acc);
end

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/爱喝兽奶帝天荒/article/detail/916928
推荐阅读
相关标签
  

闽ICP备14008679号