/*
 * Copyright (C) 2003, 2004 Bjørn-Ove Heimsund
 * 
 * This file is part of MT.
 * 
 * This library is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as published by the
 * Free Software Foundation; either version 2.1 of the License, or (at your
 * option) any later version.
 * 
 * This library is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
 * for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; if not, write to the Free Software Foundation,
 * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 */

package mt;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;

/**
 * Diagonal matrix. Only the <code>n</code> diagonal entries are stored; every
 * off-diagonal entry reads as zero and cannot be assigned. A diagonal matrix
 * equals its own transpose, so the transposed operations delegate to their
 * non-transposed counterparts.
 */
public class DiagMatrix extends AbstractMatrix implements Serializable {

	private static final long serialVersionUID = -996829843451435828L;

	/**
	 * Diagonal storage. Element <code>i</code> holds matrix entry
	 * <code>(i, i)</code>
	 */
	private double[] data;

	/**
	 * Size of the matrix (number of rows, which equals the number of columns)
	 */
	private int n;

	/**
	 * Constructor for DiagMatrix. All diagonal entries start at zero
	 * 
	 * @param n
	 *            Size of the matrix. Since the matrix must be square, this
	 *            equals both the number of rows and columns
	 */
	public DiagMatrix(int n) {
		super(n, n);

		this.n = n;
		data = new double[n];
	}

	/**
	 * Constructor for DiagMatrix which wraps an existing array
	 * 
	 * @param diag
	 *            Diagonal entries. The array itself becomes the backing
	 *            storage (it is not copied), so later changes to it are
	 *            visible through this matrix and vice versa
	 */
	public DiagMatrix(double[] diag) {
		super(diag.length, diag.length);
		this.n = diag.length;
		this.data = diag;
	}

	/**
	 * Constructor for DiagMatrix. Performs a deep copy
	 * 
	 * @param A
	 *            Matrix to copy from. Only the diagonal part is copied
	 */
	public DiagMatrix(Matrix A) {
		this(A, true);
	}

	/**
	 * Constructor for DiagMatrix
	 * 
	 * @param A
	 *            Matrix to copy from. Only the diagonal part is copied
	 * @param deep
	 *            True for a deep copy, else it's shallow. For shallow copies,
	 *            <code>A</code> must be a <code>DiagMatrix</code> (otherwise a
	 *            <code>ClassCastException</code> is thrown), and the new
	 *            matrix shares its diagonal storage with <code>A</code>
	 * @throws IllegalArgumentException
	 *             if <code>A</code> is not square
	 */
	public DiagMatrix(Matrix A, boolean deep) {
		super(A);

		if (!isSquare())
			throw new IllegalArgumentException("Diagonal matrix must be square");
		n = numRows;

		if (deep) {
			data = new double[n];
			// Copy only the diagonal entries visited by A's iterator; all
			// off-diagonal entries of A are ignored
			/*for (MatrixEntry e : A)
				if (e.row() == e.column())
					set(e.row(), e.column(), e.get());*/
			//Iterator<MatrixEntry> iter = A.iterator();
			Iterator iter = A.iterator();
			while (iter.hasNext()) {
				MatrixEntry e = (MatrixEntry) iter.next();
				if (e.row() == e.column())
					set(e.row(), e.column(), e.get());
			}
		} else {
			// Shallow copy: reuse A's diagonal array
			DiagMatrix B = (DiagMatrix) A;
			data = B.getDiagonal();
		}
	}

	/**
	 * Returns the diagonal entries. This is the internal storage array, not a
	 * copy, so writing into it modifies this matrix
	 */
	public double[] getDiagonal() {
		return data;
	}

	/**
	 * Adds <code>value</code> to entry <code>(row, column)</code>. Indices are
	 * first validated by <code>check</code>
	 * 
	 * @throws IllegalArgumentException
	 *             if <code>row != column</code>, since off-diagonal entries
	 *             are not stored
	 */
	public void add(int row, int column, double value) {
		check(row, column);
		if (row == column)
			data[row] += value;
		else
			throw new IllegalArgumentException("Insertion index outside of diagonal");
	}

	/**
	 * Returns entry <code>(row, column)</code>, which is zero for every
	 * off-diagonal position. Indices are first validated by <code>check</code>
	 */
	public double get(int row, int column) {
		check(row, column);
		if (row == column)
			return data[row];
		else
			return 0;
	}

	/**
	 * Sets entry <code>(row, column)</code> to <code>value</code>. Indices
	 * are first validated by <code>check</code>
	 * 
	 * @throws IllegalArgumentException
	 *             if <code>row != column</code>, since off-diagonal entries
	 *             are not stored
	 */
	public void set(int row, int column, double value) {
		check(row, column);
		if (row == column)
			data[row] = value;
		else
			throw new IllegalArgumentException("Insertion index outside of diagonal");
	}

	/**
	 * Returns a deep copy of this matrix (new diagonal storage)
	 */
	//public DiagMatrix copy() {
	public Matrix copy() {
		return new DiagMatrix(this);
	}

	/**
	 * Sets every diagonal entry to zero
	 * 
	 * @return this matrix
	 */
	//public DiagMatrix zero() {
	public Matrix zero() {
		Arrays.fill(data, 0);
		return this;
	}

	/**
	 * Returns this matrix itself, because a diagonal matrix is symmetric. No
	 * copy is made
	 */
	public Matrix transpose() {
		return this;
	}

	/**
	 * Computes <code>z = alpha*A*x + beta*y</code>. When both <code>x</code>
	 * and <code>z</code> are <code>DenseVector</code>s, this is done
	 * element-wise on their arrays; otherwise the superclass implementation is
	 * used
	 * 
	 * @return z
	 */
	public Vector multAdd(
		double alpha,
		Vector x,
		double beta,
		Vector y,
		Vector z) {
		if (!(x instanceof DenseVector) || !(z instanceof DenseVector))
			return super.multAdd(alpha, x, beta, y, z);

		checkMultAdd(x, y, z);

		double[] xd = ((DenseVector) x).getData();
		double[] zd = ((DenseVector) z).getData();

		// z = beta*y first, then accumulate alpha*A*x directly into z's array.
		// NOTE(review): assumes z.set(...) writes in place into the array
		// obtained from getData() above, and that x and z are distinct
		z.set(beta, y);

		if (alpha != 0)
			for (int i = 0; i < n; ++i)
				zd[i] += alpha * data[i] * xd[i];

		return z;
	}

	/**
	 * Computes <code>z = alpha*A<sup>T</sup>*x + beta*y</code>, which equals
	 * {@link #multAdd(double, Vector, double, Vector, Vector)} because
	 * <code>A<sup>T</sup> = A</code>
	 */
	public Vector transMultAdd(
		double alpha,
		Vector x,
		double beta,
		Vector y,
		Vector z) {
		return multAdd(alpha, x, beta, y, z);
	}

	/**
	 * Computes <code>D = alpha*A*B + beta*C</code>. Row <code>i</code> of
	 * <code>A*B</code> is row <code>i</code> of <code>B</code> scaled by the
	 * <code>i</code>-th diagonal entry
	 * 
	 * @return D
	 */
	public Matrix multAdd(
		double alpha,
		Matrix B,
		double beta,
		Matrix C,
		Matrix D) {
		checkMultAdd(B, C, D);

		D.set(beta, C);

		if (alpha != 0.) {
			int cols = C.numColumns();
			for (int i = 0; i < n; ++i) {
				// Same value (and rounding) as alpha * data[i] * b, which
				// associates left to right
				double scale = alpha * data[i];
				for (int j = 0; j < cols; ++j)
					D.add(i, j, scale * B.get(i, j));
			}
		}

		return D;
	}

	/**
	 * Computes <code>D = alpha*A<sup>T</sup>*B<sup>T</sup> + beta*C</code>,
	 * which equals <code>alpha*A*B<sup>T</sup> + beta*C</code> because
	 * <code>A<sup>T</sup> = A</code>
	 */
	public Matrix transABmultAdd(
		double alpha,
		Matrix B,
		double beta,
		Matrix C,
		Matrix D) {
		return transBmultAdd(alpha, B, beta, C, D);
	}

	/**
	 * Computes <code>D = alpha*A<sup>T</sup>*B + beta*C</code>, which equals
	 * <code>alpha*A*B + beta*C</code> because <code>A<sup>T</sup> = A</code>
	 */
	public Matrix transAmultAdd(
		double alpha,
		Matrix B,
		double beta,
		Matrix C,
		Matrix D) {
		return multAdd(alpha, B, beta, C, D);
	}

	/**
	 * Computes <code>D = alpha*A*B<sup>T</sup> + beta*C</code>
	 * 
	 * @return D
	 */
	public Matrix transBmultAdd(
		double alpha,
		Matrix B,
		double beta,
		Matrix C,
		Matrix D) {
		checkTransBmultAdd(B, C, D);

		D.set(beta, C);

		if (alpha != 0) {
			int cols = C.numColumns();
			for (int i = 0; i < n; ++i) {
				double scale = alpha * data[i];
				// Entry (i, j) of B^T is B(j, i)
				for (int j = 0; j < cols; ++j)
					D.add(i, j, scale * B.get(j, i));
			}
		}

		return D;
	}

	/**
	 * Solves <code>A*X = B</code> entry-wise as
	 * <code>X(i, j) = B(i, j) / A(i, i)</code>. Every column of
	 * <code>B</code> is treated as a right-hand side. Zero diagonal entries
	 * are not detected; they yield infinite or NaN entries in <code>X</code>
	 * per floating-point division
	 * 
	 * @return X
	 */
	public Matrix solve(Matrix B, Matrix X) {
		checkSolve(B, X);

		// The number of right-hand sides is B's column count, which is
		// independent of the matrix order n
		int rhs = B.numColumns();
		for (int i = 0; i < n; ++i)
			for (int j = 0; j < rhs; ++j)
				X.set(i, j, B.get(i, j) / data[i]);

		return X;
	}

	/**
	 * Solves <code>A*x = b</code> entry-wise as
	 * <code>x(i) = b(i) / A(i, i)</code>. Zero diagonal entries are not
	 * detected; they yield infinite or NaN entries in <code>x</code>
	 * 
	 * @return x
	 */
	public Vector solve(Vector b, Vector x) {
		checkSolve(b, x);

		for (int i = 0; i < n; ++i)
			x.set(i, b.get(i) / data[i]);

		return x;
	}

	/**
	 * Solves <code>A<sup>T</sup>*X = B</code>, identical to
	 * {@link #solve(Matrix, Matrix)} because <code>A<sup>T</sup> = A</code>
	 */
	public Matrix transSolve(Matrix B, Matrix X) {
		return solve(B, X);
	}

	/**
	 * Solves <code>A<sup>T</sup>*x = b</code>, identical to
	 * {@link #solve(Vector, Vector)} because <code>A<sup>T</sup> = A</code>
	 */
	public Vector transSolve(Vector b, Vector x) {
		return solve(b, x);
	}

	/**
	 * Returns an iterator over the diagonal positions of this matrix
	 */
	//public Iterator < MatrixEntry > iterator() {
	public Iterator iterator() {
		return new DiagMatrixIterator();
	}

	/**
	 * Iterator over a diagonal matrix. Row and column advance together, so
	 * only positions <code>(i, i)</code> are visited. NOTE(review): the start
	 * position, termination test and entry access come from
	 * RefMatrixIterator, which is declared elsewhere
	 */
	private class DiagMatrixIterator extends RefMatrixIterator {

		protected void nextPosition() {
			rowNext++;
			columnNext++;
		}

	}

}
