summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/matrix.rb29
-rw-r--r--test/matrix/test_matrix.rb17
2 files changed, 46 insertions, 0 deletions
diff --git a/lib/matrix.rb b/lib/matrix.rb
index 0a2b609b59..2f2d0f371a 100644
--- a/lib/matrix.rb
+++ b/lib/matrix.rb
@@ -250,6 +250,35 @@ class Matrix
end
#
+ # Create a matrix by combining matrices entrywise, using the given block
+ #
+ # x = Matrix[[6, 6], [4, 4]]
+ # y = Matrix[[1, 2], [3, 4]]
+ # Matrix.combine(x, y) {|a, b| a - b} # => Matrix[[5, 4], [1, 2]]
+ #
+ def Matrix.combine(*matrices)
+ return to_enum(__method__, *matrices) unless block_given?
+
+ return Matrix.empty if matrices.empty?
+ matrices.map!(&CoercionHelper.method(:coerce_to_matrix))
+ x = matrices.first
+ matrices.each do |m|
+ Matrix.Raise ErrDimensionMismatch unless x.row_count == m.row_count && x.column_count == m.column_count
+ end
+
+ rows = Array.new(x.row_count) do |i|
+ Array.new(x.column_count) do |j|
+ yield matrices.map{|m| m[i,j]}
+ end
+ end
+ new rows, x.column_count
+ end
+
+ def combine(*matrices, &block)
+ Matrix.combine(self, *matrices, &block)
+ end
+
+ #
# Matrix.new is private; use Matrix.rows, columns, [], etc... to create.
#
def initialize(rows, column_count = rows[0].size)
diff --git a/test/matrix/test_matrix.rb b/test/matrix/test_matrix.rb
index 5b0bc968f7..92a24d1e9f 100644
--- a/test/matrix/test_matrix.rb
+++ b/test/matrix/test_matrix.rb
@@ -592,6 +592,23 @@ class TestMatrix < Test::Unit::TestCase
assert_equal Matrix[[1],[2],[3]], Matrix.vstack(Vector[1,2], Vector[3])
end
+ def test_combine
+ x = Matrix[[6, 6], [4, 4]]
+ y = Matrix[[1, 2], [3, 4]]
+ assert_equal Matrix[[5, 4], [1, 0]], Matrix.combine(x, y) {|a, b| a - b}
+ assert_equal Matrix[[5, 4], [1, 0]], x.combine(y) {|a, b| a - b}
+ # Without block
+ assert_equal Matrix[[5, 4], [1, 0]], Matrix.combine(x, y).each {|a, b| a - b}
+ # With vectors
+ assert_equal Matrix[[111], [222]], Matrix.combine(Matrix[[1], [2]], Vector[10,20], Vector[100,200], &:sum)
+ # Basic checks
+ assert_raise(Matrix::ErrDimensionMismatch) { @m1.combine(x) { raise } }
+ # Edge cases
+ assert_equal Matrix.empty, Matrix.combine{ raise }
+ assert_equal Matrix.empty(3,0), Matrix.combine(Matrix.empty(3,0), Matrix.empty(3,0)) { raise }
+ assert_equal Matrix.empty(0,3), Matrix.combine(Matrix.empty(0,3), Matrix.empty(0,3)) { raise }
+ end
+
def test_eigenvalues_and_eigenvectors_symmetric
m = Matrix[
[8, 1],